diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java index c49247f37..9be01737c 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java @@ -33,16 +33,20 @@ public class FlintOptions implements Serializable { public static final String SCHEME = "scheme"; - public static final String AUTH = "auth"; + /** + * Service name used for SigV4 signature. + * `es`: Amazon OpenSearch Service + * `aoss`: Amazon OpenSearch Serverless + */ + public static final String SERVICE_NAME = "auth.servicename"; + public static final String SERVICE_NAME_ES = "es"; + public static final String SERVICE_NAME_AOSS = "aoss"; + public static final String AUTH = "auth"; public static final String NONE_AUTH = "noauth"; - public static final String SIGV4_AUTH = "sigv4"; - public static final String BASIC_AUTH = "basic"; - public static final String USERNAME = "auth.username"; - public static final String PASSWORD = "auth.password"; public static final String CUSTOM_AWS_CREDENTIALS_PROVIDER = "customAWSCredentialsProvider"; @@ -131,6 +135,10 @@ public String getAuth() { return options.getOrDefault(AUTH, NONE_AUTH); } + public String getServiceName() { + return options.getOrDefault(SERVICE_NAME, SERVICE_NAME_ES); + } + public String getCustomAwsCredentialsProvider() { return options.getOrDefault(CUSTOM_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER); } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptor.java b/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptor.java index a3925999e..172ac5ceb 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptor.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptor.java @@ -5,7 +5,9 @@ package org.opensearch.flint.core.auth; +import static com.amazonaws.auth.internal.SignerConstants.X_AMZ_CONTENT_SHA256; import static org.apache.http.protocol.HttpCoreContext.HTTP_TARGET_HOST; +import static org.opensearch.flint.core.FlintOptions.SERVICE_NAME_AOSS; import com.amazonaws.DefaultRequest; import com.amazonaws.auth.AWSCredentialsProvider; @@ -31,6 +33,7 @@ import org.apache.http.entity.BasicHttpEntity; import org.apache.http.message.BasicHeader; import org.apache.http.protocol.HttpContext; +import org.opensearch.flint.core.storage.OpenSearchClientUtils; /** * From https://github.com/opensearch-project/sql-jdbc/blob/main/src/main/java/org/opensearch/jdbc/transport/http/auth/aws/AWSRequestSigningApacheInterceptor.java @@ -74,13 +77,6 @@ public AWSRequestSigningApacheInterceptor(final String service, @Override public void process(final HttpRequest request, final HttpContext context) throws HttpException, IOException { - URIBuilder uriBuilder; - try { - uriBuilder = new URIBuilder(request.getRequestLine().getUri()); - } catch (URISyntaxException e) { - throw new IOException("Invalid URI" , e); - } - // Copy Apache HttpRequest to AWS DefaultRequest DefaultRequest signableRequest = new DefaultRequest<>(service); @@ -91,7 +87,10 @@ public void process(final HttpRequest request, final HttpContext context) final HttpMethodName httpMethod = HttpMethodName.fromValue(request.getRequestLine().getMethod()); signableRequest.setHttpMethod(httpMethod); + + URIBuilder uriBuilder; try { + uriBuilder = new URIBuilder(request.getRequestLine().getUri()); signableRequest.setResourcePath(uriBuilder.build().getRawPath()); } catch (URISyntaxException e) { throw new IOException("Invalid URI" , e); @@ -110,6 +109,10 @@ public void process(final HttpRequest request, final HttpContext context) signableRequest.setParameters(nvpToMapParams(uriBuilder.getQueryParams())); signableRequest.setHeaders(headerArrayToMap(request.getAllHeaders())); + if (SERVICE_NAME_AOSS.equals(service)) { + enableContentBodySignature(signableRequest); + } + // Sign it signer.sign(signableRequest, awsCredentialsProvider.getCredentials()); @@ -126,6 +129,11 @@ public void process(final HttpRequest request, final HttpContext context) } } + private void enableContentBodySignature(DefaultRequest signableRequest) { + // AWS4Signer will add `x-amz-content-sha256` header when this header is set + signableRequest.addHeader(X_AMZ_CONTENT_SHA256, "required"); + } + /** * * @param params list of HTTP query params as NameValuePairs diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java b/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java index b69343730..05b83d658 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java @@ -5,6 +5,8 @@ package org.opensearch.flint.core.auth; +import static org.opensearch.flint.core.FlintOptions.SERVICE_NAME_ES; + import com.amazonaws.auth.AWS4Signer; import com.amazonaws.auth.AWSCredentialsProvider; import org.apache.http.HttpException; @@ -14,6 +16,7 @@ import org.apache.http.protocol.HttpContext; import org.jetbrains.annotations.TestOnly; import org.opensearch.common.Strings; +import org.opensearch.flint.core.storage.OpenSearchClientUtils; import software.amazon.awssdk.authcrt.signer.AwsCrtV4aSigner; import java.io.IOException; @@ -84,7 +87,7 @@ public ResourceBasedAWSRequestSigningApacheInterceptor(final String service, @Override public void process(HttpRequest request, HttpContext context) throws HttpException, IOException { String resourcePath = parseUriToPath(request); - if ("es".equals(this.service) && isMetadataAccess(resourcePath)) { + if (SERVICE_NAME_ES.equals(this.service) && isMetadataAccess(resourcePath)) { metadataAccessInterceptor.process(request, context); } else { primaryInterceptor.process(request, context); diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java index 9277a17df..21241d7ab 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java @@ -30,8 +30,6 @@ */ public class OpenSearchClientUtils { - private static final String SERVICE_NAME = "es"; - /** * Metadata log index name prefix */ @@ -90,7 +88,7 @@ private static RestClientBuilder configureSigV4Auth(RestClientBuilder restClient restClientBuilder.setHttpClientConfigCallback(builder -> { HttpAsyncClientBuilder delegate = builder.addInterceptorLast( new ResourceBasedAWSRequestSigningApacheInterceptor( - SERVICE_NAME, options.getRegion(), customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), systemIndexName)); + options.getServiceName(), options.getRegion(), customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), systemIndexName)); return RetryableHttpAsyncClient.builder(delegate, options); } ); diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptorTest.java b/flint-core/src/test/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptorTest.java new file mode 100644 index 000000000..ae8fdfa9a --- /dev/null +++ b/flint-core/src/test/scala/org/opensearch/flint/core/auth/AWSRequestSigningApacheInterceptorTest.java @@ -0,0 +1,126 @@ +package org.opensearch.flint.core.auth; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static software.amazon.awssdk.auth.signer.internal.SignerConstant.X_AMZ_CONTENT_SHA256; + +import com.amazonaws.DefaultRequest; +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.Signer; +import com.amazonaws.http.HttpMethodName; +import com.amazonaws.util.IOUtils; +import java.io.IOException; +import java.net.URI; +import org.apache.http.HttpHost; +import org.apache.http.entity.BasicHttpEntity; +import org.apache.http.message.BasicHttpEntityEnclosingRequest; +import org.apache.http.protocol.BasicHttpContext; +import org.apache.http.protocol.HttpCoreContext; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.utils.StringInputStream; + +@ExtendWith(MockitoExtension.class) +class AWSRequestSigningApacheInterceptorTest { + + @Mock + AWSCredentialsProvider awsCredentialsProvider; + @Mock Signer signer; + @Mock + AWSCredentials awsCredentials; + + @Captor + ArgumentCaptor> signableRequestCaptor; + + @Test + public void testProcessWithServiceIsEs() throws Exception { + AWSRequestSigningApacheInterceptor awsRequestSigningApacheInterceptor = new AWSRequestSigningApacheInterceptor("es", signer, awsCredentialsProvider); + final BasicHttpEntityEnclosingRequest request = getRequestWithEntity(); + final BasicHttpContext context = getContext(); + when(awsCredentialsProvider.getCredentials()).thenReturn(awsCredentials); + + awsRequestSigningApacheInterceptor.process(request, context); + + verify(signer).sign(signableRequestCaptor.capture(), eq(awsCredentials)); + DefaultRequest signableRequest = signableRequestCaptor.getValue(); + assertEquals(new URI("http://hello.world"), signableRequest.getEndpoint()); + assertEquals(HttpMethodName.POST, signableRequest.getHttpMethod()); + assertEquals("/path", signableRequest.getResourcePath()); + assertEquals("ENTITY", IOUtils.toString(signableRequest.getContent())); + assertEquals("HeaderValue", signableRequest.getHeaders().get("Test-Header")); + assertEquals("value0", signableRequest.getParameters().get("param0").get(0)); + } + + @Test + public void testProcessWithoutEntity() throws Exception { + AWSRequestSigningApacheInterceptor awsRequestSigningApacheInterceptor = new AWSRequestSigningApacheInterceptor("es", signer, awsCredentialsProvider); + final BasicHttpEntityEnclosingRequest request = getRequest(); + final BasicHttpContext context = getContext(); + when(awsCredentialsProvider.getCredentials()).thenReturn(awsCredentials); + + awsRequestSigningApacheInterceptor.process(request, context); + + verify(signer).sign(signableRequestCaptor.capture(), eq(awsCredentials)); + DefaultRequest signableRequest = signableRequestCaptor.getValue(); + assertEquals("", IOUtils.toString(signableRequest.getContent())); + } + + @NotNull + private static BasicHttpContext getContext() { + BasicHttpContext context = new BasicHttpContext(); + context.setAttribute(HttpCoreContext.HTTP_TARGET_HOST, new HttpHost("hello.world")); + return context; + } + + @Test + public void testProcessWithServiceIsAoss() throws Exception { + AWSRequestSigningApacheInterceptor awsRequestSigningApacheInterceptor = new AWSRequestSigningApacheInterceptor("aoss", signer, awsCredentialsProvider); + final BasicHttpEntityEnclosingRequest request = getRequest(); + final BasicHttpContext context = getContext(); + when(awsCredentialsProvider.getCredentials()).thenReturn(awsCredentials); + + awsRequestSigningApacheInterceptor.process(request, context); + + verify(signer).sign(signableRequestCaptor.capture(), eq(awsCredentials)); + DefaultRequest signableRequest = signableRequestCaptor.getValue(); + assertEquals("required", signableRequest.getHeaders().get(X_AMZ_CONTENT_SHA256)); + } + + @Test + public void testInvalidURI() throws Exception { + AWSRequestSigningApacheInterceptor awsRequestSigningApacheInterceptor = new AWSRequestSigningApacheInterceptor("aoss", signer, awsCredentialsProvider); + final BasicHttpEntityEnclosingRequest request = new BasicHttpEntityEnclosingRequest("POST", "::INVALID_URI::"); + final BasicHttpContext context = getContext(); + + assertThrows(IOException.class, () -> { + awsRequestSigningApacheInterceptor.process(request, context); + }); + } + + @NotNull + private static BasicHttpEntityEnclosingRequest getRequestWithEntity() { + BasicHttpEntityEnclosingRequest request = getRequest(); + BasicHttpEntity basicHttpEntity = new BasicHttpEntity(); + basicHttpEntity.setContent(new StringInputStream("ENTITY")); + request.setEntity(basicHttpEntity); + request.setHeader("content-length", "6"); + return request; + } + + @NotNull + private static BasicHttpEntityEnclosingRequest getRequest() { + BasicHttpEntityEnclosingRequest request = new BasicHttpEntityEnclosingRequest("POST", "https://hello.world/path?param0=value0"); + request.setHeader("Test-Header", "HeaderValue"); + request.setHeader("content-length", "0"); + return request; + } +} \ No newline at end of file diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index f2f680281..7ea284959 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -57,6 +57,12 @@ object FlintSparkConf { "noauth(no auth), sigv4(sigv4 auth), basic(basic auth)") .createWithDefault(FlintOptions.NONE_AUTH) + val SERVICE_NAME = FlintConfig("spark.datasource.flint.auth.servicename") + .datasourceOption() + .doc("service name used for SigV4 signature. " + + "es (AWS OpenSearch Service), aoss (Amazon OpenSearch Serverless)") + .createWithDefault(FlintOptions.SERVICE_NAME_ES) + val USERNAME = FlintConfig("spark.datasource.flint.auth.username") .datasourceOption() .doc("basic auth username") @@ -267,6 +273,7 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable RETRYABLE_HTTP_STATUS_CODES, REGION, CUSTOM_AWS_CREDENTIALS_PROVIDER, + SERVICE_NAME, USERNAME, PASSWORD, SOCKET_TIMEOUT_MILLIS,