diff --git a/build.sbt b/build.sbt index 95324fc99..915aee126 100644 --- a/build.sbt +++ b/build.sbt @@ -63,6 +63,7 @@ lazy val flintCore = (project in file("flint-core")) exclude ("com.fasterxml.jackson.core", "jackson-databind"), "com.amazonaws" % "aws-java-sdk-cloudwatch" % "1.12.593" exclude("com.fasterxml.jackson.core", "jackson-databind"), + "software.amazon.awssdk" % "auth-crt" % "2.25.23", "org.scalactic" %% "scalactic" % "3.2.15" % "test", "org.scalatest" %% "scalatest" % "3.2.15" % "test", "org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test", diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigV4ASigningApacheInterceptor.java b/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigV4ASigningApacheInterceptor.java new file mode 100644 index 000000000..dc3fd9d41 --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigV4ASigningApacheInterceptor.java @@ -0,0 +1,237 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.auth; + +import com.amazonaws.auth.AWSSessionCredentials; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.services.glue.model.InvalidStateException; +import org.apache.http.Header; +import org.apache.http.HttpEntity; +import org.apache.http.HttpEntityEnclosingRequest; +import org.apache.http.HttpHost; +import org.apache.http.HttpRequest; +import org.apache.http.HttpRequestInterceptor; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.BasicHttpEntity; +import org.apache.http.message.BasicHeader; +import org.apache.http.protocol.HttpContext; +import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; +import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.signer.Signer; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.regions.Region; + +import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import java.util.logging.Level; +import java.util.logging.Logger; + +import static org.apache.http.protocol.HttpCoreContext.HTTP_TARGET_HOST; +import static org.opensearch.flint.core.auth.AWSRequestSigningApacheInterceptor.nvpToMapParams; +import static org.opensearch.flint.core.auth.AWSRequestSigningApacheInterceptor.skipHeader; + +/** + * Interceptor for signing AWS requests according to Signature Version 4A. + * This interceptor processes HTTP requests, signs them with AWS credentials, + * and updates the request headers to include the signature. + */ +public class AWSRequestSigV4ASigningApacheInterceptor implements HttpRequestInterceptor { + private static final Logger LOG = Logger.getLogger(AWSRequestSigV4ASigningApacheInterceptor.class.getName()); + + private static final String HTTPS_PROTOCOL = "https"; + private static final int HTTPS_PORT = 443; + + private final String service; + private final String region; + private final Signer signer; + private final AWSCredentialsProvider awsCredentialsProvider; + + /** + * Constructs an interceptor for AWS request signing with metadata access. + * + * @param service The AWS service name. + * @param region The AWS region for signing. + * @param signer The signer implementation. + * @param awsCredentialsProvider The credentials provider for metadata access. + */ + public AWSRequestSigV4ASigningApacheInterceptor(String service, String region, Signer signer, AWSCredentialsProvider awsCredentialsProvider) { + this.service = service; + this.region = region; + this.signer = signer; + this.awsCredentialsProvider = awsCredentialsProvider; + } + + /** + * Processes and signs an HTTP request, updating its headers with the signature. + * + * @param request the HTTP request to process and sign. + * @param context the HTTP context associated with the request. + * @throws IOException if an I/O error occurs during request processing. + */ + @Override + public void process(HttpRequest request, HttpContext context) throws IOException { + SdkHttpFullRequest requestToSign = buildSdkHttpRequest(request, context); + SdkHttpFullRequest signedRequest = signRequest(requestToSign); + updateRequestHeaders(request, signedRequest.headers()); + updateRequestEntity(request, signedRequest); + } + + /** + * Builds an {@link SdkHttpFullRequest} from the Apache {@link HttpRequest}. + * + * @param request the HTTP request to process and sign. + * @param context the HTTP context associated with the request. + * @return an SDK HTTP request ready to be signed. + * @throws IOException if an error occurs while building the request. + */ + private SdkHttpFullRequest buildSdkHttpRequest(HttpRequest request, HttpContext context) throws IOException { + URIBuilder uriBuilder = parseUri(request); + SdkHttpFullRequest.Builder builder = SdkHttpFullRequest.builder() + .method(SdkHttpMethod.fromValue(request.getRequestLine().getMethod())) + .protocol(HTTPS_PROTOCOL) + .port(HTTPS_PORT) + .headers(headerArrayToMap(request.getAllHeaders())) + .rawQueryParameters(nvpToMapParams(uriBuilder.getQueryParams())); + + HttpHost host = (HttpHost) context.getAttribute(HTTP_TARGET_HOST); + if (host == null) { + throw new InvalidStateException("Host must not be null"); + } + builder.host(host.getHostName()); + try { + builder.encodedPath(uriBuilder.build().getRawPath()); + } catch (URISyntaxException e) { + throw new IOException("Invalid URI", e); + } + setRequestEntity(request, builder); + return builder.build(); + } + + /** + * Sets the request entity for the {@link SdkHttpFullRequest.Builder} if the original request contains an entity. + * This is used for requests that have a body, such as POST or PUT requests. + * + * @param request the original HTTP request. + * @param builder the SDK HTTP request builder. + */ + private void setRequestEntity(HttpRequest request, SdkHttpFullRequest.Builder builder) { + if (request instanceof HttpEntityEnclosingRequest) { + HttpEntity entity = ((HttpEntityEnclosingRequest) request).getEntity(); + if (entity != null) { + builder.contentStreamProvider(() -> { + try { + return entity.getContent(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }); + } + } + } + + private URIBuilder parseUri(HttpRequest request) throws IOException { + try { + return new URIBuilder(request.getRequestLine().getUri()); + } catch (URISyntaxException e) { + throw new IOException("Invalid URI", e); + } + } + + /** + * Signs the given SDK HTTP request using the provided AWS credentials and signer. + * + * @param request the SDK HTTP request to sign. + * @return a signed SDK HTTP request. + */ + private SdkHttpFullRequest signRequest(SdkHttpFullRequest request) { + AWSSessionCredentials sessionCredentials = (AWSSessionCredentials) awsCredentialsProvider.getCredentials(); + AwsSessionCredentials awsCredentials = AwsSessionCredentials.create( + sessionCredentials.getAWSAccessKeyId(), + sessionCredentials.getAWSSecretKey(), + sessionCredentials.getSessionToken() + ); + + ExecutionAttributes executionAttributes = new ExecutionAttributes() + .putAttribute(AwsSignerExecutionAttribute.AWS_CREDENTIALS, awsCredentials) + .putAttribute(AwsSignerExecutionAttribute.SERVICE_SIGNING_NAME, service) + .putAttribute(AwsSignerExecutionAttribute.SIGNING_REGION, Region.of(region)); + + try { + return signer.sign(request, executionAttributes); + } catch (Exception e) { + LOG.log(Level.SEVERE, "Error Sigv4a signing the request", e); + throw e; + } + } + + /** + * Updates the HTTP request headers with the signed headers. + * + * @param request the original HTTP request. + * @param signedHeaders the headers after signing. + */ + private void updateRequestHeaders(HttpRequest request, Map> signedHeaders) { + Header[] headers = convertHeaderMapToArray(signedHeaders); + request.setHeaders(headers); + } + + /** + * Updates the request entity based on the signed request. This is used to update the request body after signing. + * + * @param request the original HTTP request. + * @param signedRequest the signed SDK HTTP request. + */ + private void updateRequestEntity(HttpRequest request, SdkHttpFullRequest signedRequest) { + if (request instanceof HttpEntityEnclosingRequest) { + HttpEntityEnclosingRequest httpEntityEnclosingRequest = (HttpEntityEnclosingRequest) request; + signedRequest.contentStreamProvider().ifPresent(provider -> { + InputStream contentStream = provider.newStream(); + BasicHttpEntity basicHttpEntity = new BasicHttpEntity(); + basicHttpEntity.setContent(contentStream); + signedRequest.firstMatchingHeader("Content-Length").ifPresent(value -> + basicHttpEntity.setContentLength(Long.parseLong(value))); + signedRequest.firstMatchingHeader("Content-Type").ifPresent(basicHttpEntity::setContentType); + httpEntityEnclosingRequest.setEntity(basicHttpEntity); + }); + } + } + + /** + * Converts an array of {@link Header} objects into a map, consolidating multiple values for the same header name. + * + * @param headers the array of {@link Header} objects to convert. + * @return a map where each key is a header name and each value is a list of header values. + */ + private static Map> headerArrayToMap(final Header[] headers) { + Map> headersMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + for (Header header : headers) { + if (!skipHeader(header)) { + headersMap.computeIfAbsent(header.getName(), k -> new ArrayList<>()).add(header.getValue()); + } + } + return headersMap; + } + + /** + * Converts a map of headers back into an array of {@link Header} objects. + * + * @param mapHeaders the map of headers to convert. + * @return an array of {@link Header} objects. + */ + private Header[] convertHeaderMapToArray(final Map> mapHeaders) { + return mapHeaders.entrySet().stream() + .map(entry -> new BasicHeader(entry.getKey(), String.join(",", entry.getValue()))) + .toArray(Header[]::new); + } +} \ No newline at end of file diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptor.java b/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptor.java index c11677c3f..dd5fd78bc 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptor.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptor.java @@ -126,7 +126,7 @@ public void process(final HttpRequest request, final HttpContext context) * @param params list of HTTP query params as NameValuePairs * @return a multimap of HTTP query params */ - private static Map> nvpToMapParams(final List params) { + static Map> nvpToMapParams(final List params) { Map> parameterMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); for (NameValuePair nvp : params) { List argsList = @@ -154,10 +154,11 @@ private static Map headerArrayToMap(final Header[] headers) { * @param header header line to check * @return true if the given header should be excluded when signing */ - private static boolean skipHeader(final Header header) { + static boolean skipHeader(final Header header) { return ("content-length".equalsIgnoreCase(header.getName()) - && "0".equals(header.getValue())) // Strip Content-Length: 0 - || "host".equalsIgnoreCase(header.getName()); // Host comes from endpoint + && "0".equals(header.getValue())) // Strip Content-Length: 0 + || "host".equalsIgnoreCase(header.getName()) // Host comes from endpoint + || "connection".equalsIgnoreCase(header.getName()); // Skip setting Connection manually } /** diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java b/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java index c3e65fef3..b69343730 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java @@ -1,12 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.flint.core.auth; +import com.amazonaws.auth.AWS4Signer; import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.Signer; import org.apache.http.HttpException; import org.apache.http.HttpRequest; import org.apache.http.HttpRequestInterceptor; import org.apache.http.client.utils.URIBuilder; import org.apache.http.protocol.HttpContext; +import org.jetbrains.annotations.TestOnly; +import org.opensearch.common.Strings; +import software.amazon.awssdk.authcrt.signer.AwsCrtV4aSigner; import java.io.IOException; import java.net.URISyntaxException; @@ -19,33 +27,45 @@ public class ResourceBasedAWSRequestSigningApacheInterceptor implements HttpRequ private final String service; private final String metadataAccessIdentifier; - final AWSRequestSigningApacheInterceptor primaryInterceptor; - final AWSRequestSigningApacheInterceptor metadataAccessInterceptor; + final HttpRequestInterceptor primaryInterceptor; + final HttpRequestInterceptor metadataAccessInterceptor; /** * Constructs an interceptor for AWS request signing with optional metadata access. * * @param service The AWS service name. - * @param signer The AWS request signer. + * @param region The AWS region for signing. * @param primaryCredentialsProvider The credentials provider for general access. * @param metadataAccessCredentialsProvider The credentials provider for metadata access. * @param metadataAccessIdentifier Identifier for operations requiring metadata access. */ public ResourceBasedAWSRequestSigningApacheInterceptor(final String service, - final Signer signer, + final String region, final AWSCredentialsProvider primaryCredentialsProvider, final AWSCredentialsProvider metadataAccessCredentialsProvider, final String metadataAccessIdentifier) { - this(service, - new AWSRequestSigningApacheInterceptor(service, signer, primaryCredentialsProvider), - new AWSRequestSigningApacheInterceptor(service, signer, metadataAccessCredentialsProvider), - metadataAccessIdentifier); + if (Strings.isNullOrEmpty(service)) { + throw new IllegalArgumentException("Service name must not be null or empty."); + } + if (Strings.isNullOrEmpty(region)) { + throw new IllegalArgumentException("Region must not be null or empty."); + } + this.service = service; + this.metadataAccessIdentifier = metadataAccessIdentifier; + AWS4Signer signer = new AWS4Signer(); + signer.setServiceName(service); + signer.setRegionName(region); + this.primaryInterceptor = new AWSRequestSigningApacheInterceptor(service, signer, primaryCredentialsProvider); + this.metadataAccessInterceptor = primaryCredentialsProvider.equals(metadataAccessCredentialsProvider) + ? this.primaryInterceptor + : new AWSRequestSigV4ASigningApacheInterceptor(service, region, AwsCrtV4aSigner.builder().build(), metadataAccessCredentialsProvider); } // Test constructor allowing injection of mock interceptors + @TestOnly ResourceBasedAWSRequestSigningApacheInterceptor(final String service, - final AWSRequestSigningApacheInterceptor primaryInterceptor, - final AWSRequestSigningApacheInterceptor metadataAccessInterceptor, + final HttpRequestInterceptor primaryInterceptor, + final HttpRequestInterceptor metadataAccessInterceptor, final String metadataAccessIdentifier) { this.service = service == null ? "unknown" : service; this.primaryInterceptor = primaryInterceptor; @@ -94,6 +114,6 @@ private String parseUriToPath(HttpRequest request) throws IOException { * @return true if the operation requires metadata access credentials, false otherwise. */ private boolean isMetadataAccess(String resourcePath) { - return resourcePath.contains(metadataAccessIdentifier); + return !Strings.isNullOrEmpty(metadataAccessIdentifier) && resourcePath.contains(metadataAccessIdentifier); } } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java index b03ac0c6f..e71e3ded5 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java @@ -7,7 +7,6 @@ import static org.opensearch.common.xcontent.DeprecationHandler.IGNORE_DEPRECATIONS; -import com.amazonaws.auth.AWS4Signer; import com.amazonaws.auth.AWSCredentialsProvider; import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; import java.io.IOException; @@ -68,6 +67,9 @@ public class FlintOpenSearchClient implements FlintClient { private static final Logger LOG = Logger.getLogger(FlintOpenSearchClient.class.getName()); + private static final String SERVICE_NAME = "es"; + + /** * {@link NamedXContentRegistry} from {@link SearchModule} used for construct {@link QueryBuilder} from DSL query string. */ @@ -98,8 +100,7 @@ public FlintOpenSearchClient(FlintOptions options) { public OptimisticTransaction startTransaction( String indexName, String dataSourceName, boolean forceInit) { LOG.info("Starting transaction on index " + indexName + " and data source " + dataSourceName); - String metaLogIndexName = dataSourceName.isEmpty() ? META_LOG_NAME_PREFIX - : META_LOG_NAME_PREFIX + "_" + dataSourceName; + String metaLogIndexName = constructMetaLogIndexName(dataSourceName); try (IRestHighLevelClient client = createClient()) { if (client.doesIndexExist(new GetIndexRequest(metaLogIndexName), RequestOptions.DEFAULT)) { LOG.info("Found metadata log index " + metaLogIndexName); @@ -257,10 +258,6 @@ public IRestHighLevelClient createClient() { // SigV4 support if (options.getAuth().equals(FlintOptions.SIGV4_AUTH)) { - AWS4Signer signer = new AWS4Signer(); - signer.setServiceName("es"); - signer.setRegionName(options.getRegion()); - // Use DefaultAWSCredentialsProviderChain by default. final AtomicReference customAWSCredentialsProvider = new AtomicReference<>(new DefaultAWSCredentialsProviderChain()); @@ -274,6 +271,10 @@ public IRestHighLevelClient createClient() { String metadataAccessProviderClass = options.getMetadataAccessAwsCredentialsProvider(); final AtomicReference metadataAccessAWSCredentialsProvider = new AtomicReference<>(new DefaultAWSCredentialsProviderChain()); + + String metaLogIndexName = constructMetaLogIndexName(options.getDataSourceName()); + String systemIndexName = Strings.isNullOrEmpty(options.getSystemIndexName()) ? metaLogIndexName : options.getSystemIndexName(); + if (Strings.isNullOrEmpty(metadataAccessProviderClass)) { metadataAccessAWSCredentialsProvider.set(customAWSCredentialsProvider.get()); } else { @@ -283,7 +284,7 @@ public IRestHighLevelClient createClient() { restClientBuilder.setHttpClientConfigCallback(builder -> { HttpAsyncClientBuilder delegate = builder.addInterceptorLast( new ResourceBasedAWSRequestSigningApacheInterceptor( - signer.getServiceName(), signer, customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), options.getSystemIndexName())); + SERVICE_NAME, options.getRegion(), customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), systemIndexName)); return RetryableHttpAsyncClient.builder(delegate, options); } ); @@ -385,4 +386,8 @@ private String sanitizeIndexName(String indexName) { String encoded = percentEncode(indexName); return toLowercase(encoded); } + + private String constructMetaLogIndexName(String dataSourceName) { + return dataSourceName.isEmpty() ? META_LOG_NAME_PREFIX : META_LOG_NAME_PREFIX + "_" + dataSourceName; + } } diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/auth/AWSRequestSigV4ASigningApacheInterceptorTest.java b/flint-core/src/test/scala/org/opensearch/flint/core/auth/AWSRequestSigV4ASigningApacheInterceptorTest.java new file mode 100644 index 000000000..9045fa9c1 --- /dev/null +++ b/flint-core/src/test/scala/org/opensearch/flint/core/auth/AWSRequestSigV4ASigningApacheInterceptorTest.java @@ -0,0 +1,148 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.auth; + +import org.apache.http.HttpHost; +import org.apache.http.HttpRequest; +import org.apache.http.message.BasicHttpRequest; +import org.apache.http.protocol.HttpContext; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.AWSSessionCredentials; + +import software.amazon.awssdk.core.signer.Signer; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; + +import static org.apache.http.protocol.HttpCoreContext.HTTP_TARGET_HOST; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +public class AWSRequestSigV4ASigningApacheInterceptorTest { + @Mock + private Signer mockSigner; + @Mock + private AWSCredentialsProvider mockCredentialsProvider; + @Mock + private HttpContext mockContext; + @Mock + private AWSSessionCredentials mockSessionCredentials; + @Mock + private SdkHttpFullRequest mockSdkHttpFullRequest; + @Captor + private ArgumentCaptor sdkHttpFullRequestCaptor; + + private AWSRequestSigV4ASigningApacheInterceptor interceptor; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + when(mockCredentialsProvider.getCredentials()).thenReturn(mockSessionCredentials); + when(mockSessionCredentials.getAWSAccessKeyId()).thenReturn("ACCESS_KEY_ID"); + when(mockSessionCredentials.getAWSSecretKey()).thenReturn("SECRET_ACCESS_KEY"); + when(mockSessionCredentials.getSessionToken()).thenReturn("SESSION_TOKEN"); + interceptor = new AWSRequestSigV4ASigningApacheInterceptor("s3", "us-west-2", mockSigner, mockCredentialsProvider); + when(mockContext.getAttribute(HTTP_TARGET_HOST)).thenReturn(new HttpHost("localhost", 443, "https")); + when(mockSigner.sign(any(), any())).thenReturn(mockSdkHttpFullRequest); + } + + @Test + public void testSigningProcess() throws Exception { + HttpRequest request = new BasicHttpRequest("GET", "/path/to/resource"); + interceptor.process(request, mockContext); + + verify(mockSigner).sign(sdkHttpFullRequestCaptor.capture(), any()); + SdkHttpFullRequest signedRequest = sdkHttpFullRequestCaptor.getValue(); + + assertEquals(SdkHttpMethod.GET, signedRequest.method()); + assertEquals("/path/to/resource", signedRequest.encodedPath()); + } + + @Test(expected = IOException.class) + public void testInvalidUriHandling() throws Exception { + HttpRequest request = new BasicHttpRequest("GET", ":///this/is/not/a/valid/uri"); + interceptor.process(request, mockContext); + } + + @Test + public void testHeaderUpdateAfterSigning() throws Exception { + // Setup mock signer to return a new SdkHttpFullRequest with an "Authorization" header + when(mockSigner.sign(any(SdkHttpFullRequest.class), any())).thenAnswer(invocation -> { + SdkHttpFullRequest originalRequest = invocation.getArgument(0); + Map> modifiedHeaders = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + modifiedHeaders.putAll(originalRequest.headers()); + modifiedHeaders.put("Authorization", List.of("AWS4-ECDSA-P256-SHA256 Credential=...")); + + // Build a new SdkHttpFullRequest with the modified headers + return SdkHttpFullRequest.builder() + .method(originalRequest.method()) + .uri(originalRequest.getUri()) + .headers(modifiedHeaders) + .build(); + }); + + HttpRequest request = new BasicHttpRequest("GET", "/path/to/resource"); + interceptor.process(request, mockContext); + + // Now verify that the HttpRequest has been updated with the new headers from the signed request + assertTrue("The request does not contain the expected 'Authorization' header", + request.containsHeader("Authorization")); + assertEquals("AWS4-ECDSA-P256-SHA256 Credential=...", + request.getFirstHeader("Authorization").getValue()); + } + + @Test + public void testSigningProcessWithCorrectHostFormat() throws Exception { + HttpRequest request = new BasicHttpRequest("GET", "/path/to/resource"); + + // Setup the interceptor with a mock HTTP context to return an HttpHost with scheme and port + HttpHost expectedHost = new HttpHost("localhost", 443, "https"); + when(mockContext.getAttribute(HTTP_TARGET_HOST)).thenReturn(expectedHost); + + interceptor.process(request, mockContext); + + // Capture the SdkHttpFullRequest passed to the signer + verify(mockSigner).sign(sdkHttpFullRequestCaptor.capture(), any()); + SdkHttpFullRequest signedRequest = sdkHttpFullRequestCaptor.getValue(); + + // Assert method and path + assertEquals(SdkHttpMethod.GET, signedRequest.method()); + assertEquals("/path/to/resource", signedRequest.encodedPath()); + + // Verify the host format is correct (hostname only, without scheme and port) + String expectedHostName = "localhost"; // Expected hostname without scheme and port + assertEquals("The host in the signed request should contain only the hostname, without scheme and port.", + expectedHostName, signedRequest.host()); + } + + @Test + public void testConnectionHeaderIsSkippedDuringSigning() throws Exception { + // Create a new HttpRequest with a Connection header + HttpRequest request = new BasicHttpRequest("GET", "/path/to/resource"); + request.addHeader("Connection", "keep-alive"); + + interceptor.process(request, mockContext); + + // Verify that the SdkHttpFullRequest passed to the signer does not contain the Connection header + verify(mockSigner).sign(sdkHttpFullRequestCaptor.capture(), any()); + SdkHttpFullRequest signedRequest = sdkHttpFullRequestCaptor.getValue(); + + // Assert that the signed request does not have the Connection header + assertFalse("The signed request should not contain a 'Connection' header.", + signedRequest.headers().containsKey("Connection")); + } +} \ No newline at end of file diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptorTest.java b/flint-core/src/test/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptorTest.java index 0ef021b53..e3edf3f73 100644 --- a/flint-core/src/test/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptorTest.java +++ b/flint-core/src/test/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptorTest.java @@ -1,8 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.flint.core.auth; import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.AWSSessionCredentials; import com.amazonaws.auth.Signer; +import org.apache.http.HttpHost; import org.apache.http.HttpRequest; +import org.apache.http.HttpRequestInterceptor; import org.apache.http.message.BasicHttpRequest; import org.apache.http.protocol.HttpContext; import org.junit.Before; @@ -11,19 +19,25 @@ import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import static org.apache.http.protocol.HttpCoreContext.HTTP_TARGET_HOST; import static org.mockito.Mockito.*; public class ResourceBasedAWSRequestSigningApacheInterceptorTest { @Mock - private Signer mockSigner; + private Signer mockPrimarySigner; + @Mock + private software.amazon.awssdk.core.signer.Signer mockSigV4ASigner; @Mock private AWSCredentialsProvider mockPrimaryCredentialsProvider; @Mock private AWSCredentialsProvider mockMetadataAccessCredentialsProvider; @Mock private HttpContext mockContext; + @Mock + private SdkHttpFullRequest mockSdkHttpFullRequest; @Captor private ArgumentCaptor httpRequestCaptor; @@ -32,8 +46,20 @@ public class ResourceBasedAWSRequestSigningApacheInterceptorTest { @Before public void setUp() { MockitoAnnotations.initMocks(this); - AWSRequestSigningApacheInterceptor primaryInterceptorSpy = spy(new AWSRequestSigningApacheInterceptor("es", mockSigner, mockPrimaryCredentialsProvider)); - AWSRequestSigningApacheInterceptor metadataInterceptorSpy = spy(new AWSRequestSigningApacheInterceptor("es", mockSigner, mockMetadataAccessCredentialsProvider)); + + // Configure the mockMetadataAccessCredentialsProvider to return the mock credentials + AWSSessionCredentials mockSessionCredentials = mock(AWSSessionCredentials.class); + when(mockSessionCredentials.getAWSAccessKeyId()).thenReturn("accessKey"); + when(mockSessionCredentials.getAWSSecretKey()).thenReturn("secretKey"); + when(mockSessionCredentials.getSessionToken()).thenReturn("sessionToken"); + when(mockMetadataAccessCredentialsProvider.getCredentials()).thenReturn(mockSessionCredentials); + + HttpRequestInterceptor primaryInterceptorSpy = spy(new AWSRequestSigningApacheInterceptor("es", mockPrimarySigner, mockPrimaryCredentialsProvider)); + HttpRequestInterceptor metadataInterceptorSpy = spy(new AWSRequestSigV4ASigningApacheInterceptor("es", "us-east-1", mockSigV4ASigner, mockMetadataAccessCredentialsProvider)); + + // Configure the mockMetadataAccessCredentialsProvider to avoid NPEs + when(mockContext.getAttribute(HTTP_TARGET_HOST)).thenReturn(new HttpHost("http://localhost")); + when(mockSigV4ASigner.sign(any(), any())).thenReturn(mockSdkHttpFullRequest); interceptor = new ResourceBasedAWSRequestSigningApacheInterceptor( "es", @@ -45,6 +71,7 @@ public void setUp() { @Test public void testProcessWithMetadataAccess() throws Exception { HttpRequest request = new BasicHttpRequest("GET", "/es/metadata/resource"); + request.addHeader("Content-Type", "application/json"); interceptor.process(request, mockContext); diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 5a0918d4a..1016adaba 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -28,8 +28,8 @@ import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.flint.config.FlintSparkConf -import org.apache.spark.sql.util.{DefaultShutdownHookManager, ShutdownHookManagerTrait} import org.apache.spark.util.ThreadUtils /** @@ -136,12 +136,22 @@ object FlintREPL extends Logging with FlintJobExecutor { val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex.get) val sessionTimerContext = getTimerContext(MetricConstants.REPL_PROCESSING_TIME_METRIC) - addShutdownHook( - flintSessionIndexUpdater, - osClient, - sessionIndex.get, - sessionId.get, - sessionTimerContext) + /** + * Transition the session update logic from {@link + * org.apache.spark.util.ShutdownHookManager} to {@link SparkListenerApplicationEnd}. This + * change helps prevent interruptions to asynchronous SigV4A signing during REPL shutdown. + * + * Cancelling an EMR job directly when SigV4a signer in use could otherwise lead to stale + * sessions. For tracking, see the GitHub issue: + * https://github.com/opensearch-project/opensearch-spark/issues/320 + */ + spark.sparkContext.addSparkListener( + new PreShutdownListener( + flintSessionIndexUpdater, + osClient, + sessionIndex.get, + sessionId.get, + sessionTimerContext)) // 1 thread for updating heart beat val threadPool = @@ -899,16 +909,18 @@ object FlintREPL extends Logging with FlintJobExecutor { flintReader } - def addShutdownHook( + class PreShutdownListener( flintSessionIndexUpdater: OpenSearchUpdater, osClient: OSClient, sessionIndex: String, sessionId: String, - sessionTimerContext: Timer.Context, - shutdownHookManager: ShutdownHookManagerTrait = DefaultShutdownHookManager): Unit = { + sessionTimerContext: Timer.Context) + extends SparkListener + with Logging { - shutdownHookManager.addShutdownHook(() => { + override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { logInfo("Shutting down REPL") + logInfo("earlyExitFlag: " + earlyExitFlag) val getResponse = osClient.getDoc(sessionIndex, sessionId) if (!getResponse.isExists()) { return @@ -936,7 +948,7 @@ object FlintREPL extends Logging with FlintJobExecutor { sessionId, sessionTimerContext) } - }) + } } private def updateFlintInstanceBeforeShutdown( @@ -1124,6 +1136,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = { + logInfo("Session Success") stopTimer(sessionTimerContext) if (sessionRunningCount.get() > 0) { sessionRunningCount.decrementAndGet() diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 288eeb7c5..ea789c161 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -16,7 +16,6 @@ import scala.concurrent.duration.{Duration, MINUTES} import scala.reflect.runtime.universe.TypeTag import com.codahale.metrics.Timer -import org.mockito.ArgumentMatchers.{eq => eqTo, _} import org.mockito.ArgumentMatchersSugar import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -28,6 +27,8 @@ import org.opensearch.search.sort.SortOrder import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.SparkListenerApplicationEnd +import org.apache.spark.sql.FlintREPL.PreShutdownListener import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.flint.config.FlintSparkConf @@ -162,17 +163,18 @@ class FlintREPLTest verify(flintSessionUpdater, atLeastOnce()).upsert(eqTo("session1"), *) } - test("createShutdownHook add shutdown hook and update FlintInstance if conditions are met") { - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + test("PreShutdownListener updates FlintInstance if conditions are met") { + // Mock dependencies val osClient = mock[OSClient] val getResponse = mock[GetResponse] + val flintSessionIndexUpdater = mock[OpenSearchUpdater] val sessionIndex = "testIndex" val sessionId = "testSessionId" - val flintSessionContext = mock[Timer.Context] + val timerContext = mock[Timer.Context] + // Setup the getDoc to return a document indicating the session is running when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) - when(getResponse.getSourceAsMap).thenReturn( Map[String, Object]( "applicationId" -> "app1", @@ -184,22 +186,18 @@ class FlintREPLTest "state" -> "running", "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) - val mockShutdownHookManager = new ShutdownHookManagerTrait { - override def addShutdownHook(hook: () => Unit): AnyRef = { - hook() // execute the hook immediately - new Object() // return a dummy AnyRef as per the method signature - } - } - - // Here, we're injecting our mockShutdownHookManager into the method - FlintREPL.addShutdownHook( + // Instantiate the listener + val listener = new PreShutdownListener( flintSessionIndexUpdater, osClient, sessionIndex, sessionId, - flintSessionContext, - mockShutdownHookManager) + timerContext) + + // Simulate application end + listener.onApplicationEnd(SparkListenerApplicationEnd(System.currentTimeMillis())) + // Verify the update is called with the correct arguments verify(flintSessionIndexUpdater).updateIf(*, *, *, *) }