diff --git a/MAINTAINERS.md b/MAINTAINERS.md
index 0f2193ce0..83fd25a0c 100644
--- a/MAINTAINERS.md
+++ b/MAINTAINERS.md
@@ -16,3 +16,4 @@ This document contains a list of maintainers in this repo. See [opensearch-proje
 | Sean Kao | [seankao-az](https://github.com/seankao-az) | Amazon |
 | Anirudha Jadhav | [anirudha](https://github.com/anirudha) | Amazon |
 | Kaituo Li | [kaituo](https://github.com/kaituo) | Amazon |
+| Louis Chu | [noCharger](https://github.com/noCharger) | Amazon |
diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java
index ee38bbb9c..ee78aa512 100644
--- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java
+++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java
@@ -102,9 +102,10 @@ OptimisticTransaction startTransaction(String indexName, String dataSourc
    * @return {@link FlintWriter}
    */
   FlintWriter createWriter(String indexName);
+
   /**
    * Create {@link IRestHighLevelClient}.
    * @return {@link IRestHighLevelClient}
    */
-  public IRestHighLevelClient createClient();
+  IRestHighLevelClient createClient();
 }
diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java
index 1282e1c94..9858ffd1e 100644
--- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java
+++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java
@@ -5,7 +5,6 @@
 
 package org.opensearch.flint.core;
 
-import dev.failsafe.RetryPolicy;
 import java.io.Serializable;
 import java.util.Map;
 import org.opensearch.flint.core.http.FlintRetryOptions;
@@ -46,10 +45,14 @@ public class FlintOptions implements Serializable {
 
   public static final String CUSTOM_AWS_CREDENTIALS_PROVIDER = "customAWSCredentialsProvider";
 
+  public static final String METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER = "spark.metadata.accessAWSCredentialsProvider";
+
   /**
-   * By default, customAWSCredentialsProvider is empty. use DefaultAWSCredentialsProviderChain.
+   * By default, customAWSCredentialsProvider and accessAWSCredentialsProvider are empty. Use DefaultAWSCredentialsProviderChain.
    */
-  public static final String DEFAULT_CUSTOM_AWS_CREDENTIALS_PROVIDER = "";
+  public static final String DEFAULT_AWS_CREDENTIALS_PROVIDER = "";
+
+  public static final String SYSTEM_INDEX_KEY_NAME = "spark.flint.job.requestIndex";
 
   /**
    * Used by {@link org.opensearch.flint.core.storage.OpenSearchScrollReader}
@@ -121,7 +124,11 @@ public String getAuth() {
   }
 
   public String getCustomAwsCredentialsProvider() {
-    return options.getOrDefault(CUSTOM_AWS_CREDENTIALS_PROVIDER, "");
+    return options.getOrDefault(CUSTOM_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER);
+  }
+
+  public String getMetadataAccessAwsCredentialsProvider() {
+    return options.getOrDefault(METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER);
   }
 
   public String getUsername() {
@@ -139,4 +146,8 @@ public int getSocketTimeoutMillis() {
   public String getDataSourceName() {
     return options.getOrDefault(DATA_SOURCE_NAME, "");
   }
+
+  public String getSystemIndexName() {
+    return options.getOrDefault(SYSTEM_INDEX_KEY_NAME, "");
+  }
 }
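Both new getters resolve against the same empty-string default, so a caller can only distinguish "not configured" from a configured class name by checking for emptiness. A minimal, self-contained sketch of that lookup behavior; the demo class and option map are illustrative, and `com.example.MetadataAccessCredentialsProvider` is the placeholder class name the test suite below also uses:

```java
import java.util.Map;

public class FlintOptionsDefaultsDemo {
    // Keys mirror the constants added to FlintOptions above
    static final String CUSTOM_AWS_CREDENTIALS_PROVIDER = "customAWSCredentialsProvider";
    static final String METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER =
        "spark.metadata.accessAWSCredentialsProvider";
    static final String DEFAULT_AWS_CREDENTIALS_PROVIDER = "";

    public static void main(String[] args) {
        // Only the metadata access provider is configured here
        Map<String, String> options = Map.of(
            METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER,
            "com.example.MetadataAccessCredentialsProvider");

        // Mirrors getCustomAwsCredentialsProvider(): falls back to the shared empty default
        String custom = options.getOrDefault(
            CUSTOM_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER);
        // Mirrors getMetadataAccessAwsCredentialsProvider()
        String metadata = options.getOrDefault(
            METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER);

        System.out.println("custom='" + custom + "', metadata='" + metadata + "'");
        // Prints: custom='', metadata='com.example.MetadataAccessCredentialsProvider'
    }
}
```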
diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java b/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java
new file mode 100644
index 000000000..c3e65fef3
--- /dev/null
+++ b/flint-core/src/main/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptor.java
@@ -0,0 +1,99 @@
+package org.opensearch.flint.core.auth;
+
+import com.amazonaws.auth.AWSCredentialsProvider;
+import com.amazonaws.auth.Signer;
+import org.apache.http.HttpException;
+import org.apache.http.HttpRequest;
+import org.apache.http.HttpRequestInterceptor;
+import org.apache.http.client.utils.URIBuilder;
+import org.apache.http.protocol.HttpContext;
+
+import java.io.IOException;
+import java.net.URISyntaxException;
+
+/**
+ * Intercepts HTTP requests to sign them for AWS authentication, adjusting the signing process
+ * based on whether the request accesses metadata or not.
+ */
+public class ResourceBasedAWSRequestSigningApacheInterceptor implements HttpRequestInterceptor {
+
+  private final String service;
+  private final String metadataAccessIdentifier;
+  final AWSRequestSigningApacheInterceptor primaryInterceptor;
+  final AWSRequestSigningApacheInterceptor metadataAccessInterceptor;
+
+  /**
+   * Constructs an interceptor for AWS request signing with optional metadata access.
+   *
+   * @param service The AWS service name.
+   * @param signer The AWS request signer.
+   * @param primaryCredentialsProvider The credentials provider for general access.
+   * @param metadataAccessCredentialsProvider The credentials provider for metadata access.
+   * @param metadataAccessIdentifier Identifier for operations requiring metadata access.
+   */
+  public ResourceBasedAWSRequestSigningApacheInterceptor(final String service,
+      final Signer signer,
+      final AWSCredentialsProvider primaryCredentialsProvider,
+      final AWSCredentialsProvider metadataAccessCredentialsProvider,
+      final String metadataAccessIdentifier) {
+    this(service,
+        new AWSRequestSigningApacheInterceptor(service, signer, primaryCredentialsProvider),
+        new AWSRequestSigningApacheInterceptor(service, signer, metadataAccessCredentialsProvider),
+        metadataAccessIdentifier);
+  }
+
+  // Test constructor allowing injection of mock interceptors
+  ResourceBasedAWSRequestSigningApacheInterceptor(final String service,
+      final AWSRequestSigningApacheInterceptor primaryInterceptor,
+      final AWSRequestSigningApacheInterceptor metadataAccessInterceptor,
+      final String metadataAccessIdentifier) {
+    this.service = service == null ? "unknown" : service;
+    this.primaryInterceptor = primaryInterceptor;
+    this.metadataAccessInterceptor = metadataAccessInterceptor;
+    this.metadataAccessIdentifier = metadataAccessIdentifier;
+  }
+
+  /**
+   * Processes an HTTP request, signing it according to whether it requires metadata access.
+   *
+   * @param request The HTTP request to process.
+   * @param context The context in which the HTTP request is being processed.
+   * @throws HttpException If processing the HTTP request results in an exception.
+   * @throws IOException If an I/O error occurs.
+   */
+  @Override
+  public void process(HttpRequest request, HttpContext context) throws HttpException, IOException {
+    String resourcePath = parseUriToPath(request);
+    if ("es".equals(this.service) && isMetadataAccess(resourcePath)) {
+      metadataAccessInterceptor.process(request, context);
+    } else {
+      primaryInterceptor.process(request, context);
+    }
+  }
+
+  /**
+   * Extracts and returns the path component of a URI from an HTTP request.
+   *
+   * @param request The HTTP request from which to extract the URI path.
+   * @return The path component of the URI.
+   * @throws IOException If an error occurs parsing the URI.
+   */
+  private String parseUriToPath(HttpRequest request) throws IOException {
+    try {
+      URIBuilder uriBuilder = new URIBuilder(request.getRequestLine().getUri());
+      return uriBuilder.build().getRawPath();
+    } catch (URISyntaxException e) {
+      throw new IOException("Invalid URI", e);
+    }
+  }
+
+  /**
+   * Determines whether the accessed resource requires metadata credentials.
+   *
+   * @param resourcePath The path of the resource being accessed.
+   * @return true if the operation requires metadata access credentials, false otherwise.
+   */
+  private boolean isMetadataAccess(String resourcePath) {
+    return resourcePath.contains(metadataAccessIdentifier);
+  }
+}
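The routing decision is a substring match on the request path returned by parseUriToPath(). A standalone sketch of that predicate, substituting java.net.URI for the httpclient URIBuilder; the index name `.query_execution_request` is hypothetical:

```java
import java.net.URI;
import java.net.URISyntaxException;

public class MetadataRoutingDemo {
    // Mirrors isMetadataAccess(): match on the raw path, ignoring the query string
    static boolean isMetadataAccess(String requestUri, String metadataAccessIdentifier)
            throws URISyntaxException {
        String rawPath = new URI(requestUri).getRawPath();
        return rawPath.contains(metadataAccessIdentifier);
    }

    public static void main(String[] args) throws URISyntaxException {
        String systemIndex = ".query_execution_request"; // hypothetical system index name
        // Request against the system index -> metadata access credentials
        System.out.println(isMetadataAccess("/.query_execution_request/_search?q=x", systemIndex)); // true
        // Ordinary Flint index request -> primary credentials
        System.out.println(isMetadataAccess("/flint_myindex/_search", systemIndex));                // false
    }
}
```

Note that an empty identifier makes contains("") match every path; FlintOpenSearchClient below keeps that harmless by pointing the metadata provider at the custom provider whenever no dedicated class is configured.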
diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java
index 1c15af357..b03ac0c6f 100644
--- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java
+++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java
@@ -46,7 +46,7 @@
 import org.opensearch.flint.core.FlintClient;
 import org.opensearch.flint.core.FlintOptions;
 import org.opensearch.flint.core.IRestHighLevelClient;
-import org.opensearch.flint.core.auth.AWSRequestSigningApacheInterceptor;
+import org.opensearch.flint.core.auth.ResourceBasedAWSRequestSigningApacheInterceptor;
 import org.opensearch.flint.core.http.RetryableHttpAsyncClient;
 import org.opensearch.flint.core.metadata.FlintMetadata;
 import org.opensearch.flint.core.metadata.log.DefaultOptimisticTransaction;
@@ -262,26 +262,30 @@ public IRestHighLevelClient createClient() {
       signer.setRegionName(options.getRegion());
 
       // Use DefaultAWSCredentialsProviderChain by default.
-      final AtomicReference<AWSCredentialsProvider> awsCredentialsProvider =
-          new AtomicReference<>(new DefaultAWSCredentialsProviderChain());
-      String providerClass = options.getCustomAwsCredentialsProvider();
-      if (!Strings.isNullOrEmpty(providerClass)) {
-        try {
-          Class<?> awsCredentialsProviderClass = Class.forName(providerClass);
-          Constructor<?> ctor = awsCredentialsProviderClass.getDeclaredConstructor();
-          ctor.setAccessible(true);
-          awsCredentialsProvider.set((AWSCredentialsProvider) ctor.newInstance());
-        } catch (Exception e) {
-          throw new RuntimeException(e);
-        }
+      final AtomicReference<AWSCredentialsProvider> customAWSCredentialsProvider =
+          new AtomicReference<>(new DefaultAWSCredentialsProviderChain());
+      String customProviderClass = options.getCustomAwsCredentialsProvider();
+      if (!Strings.isNullOrEmpty(customProviderClass)) {
+        instantiateProvider(customProviderClass, customAWSCredentialsProvider);
+      }
+
+      // Set metadataAccessAWSCredentialsProvider to customAWSCredentialsProvider by default for backwards compatibility
+      // unless a specific metadata access provider class name is provided
+      String metadataAccessProviderClass = options.getMetadataAccessAwsCredentialsProvider();
+      final AtomicReference<AWSCredentialsProvider> metadataAccessAWSCredentialsProvider =
+          new AtomicReference<>(new DefaultAWSCredentialsProviderChain());
+      if (Strings.isNullOrEmpty(metadataAccessProviderClass)) {
+        metadataAccessAWSCredentialsProvider.set(customAWSCredentialsProvider.get());
+      } else {
+        instantiateProvider(metadataAccessProviderClass, metadataAccessAWSCredentialsProvider);
       }
+
       restClientBuilder.setHttpClientConfigCallback(builder -> {
-        HttpAsyncClientBuilder delegate =
-            builder.addInterceptorLast(
-                new AWSRequestSigningApacheInterceptor(
-                    signer.getServiceName(), signer, awsCredentialsProvider.get()));
-        return RetryableHttpAsyncClient.builder(delegate, options);
-      }
+        HttpAsyncClientBuilder delegate = builder.addInterceptorLast(
+            new ResourceBasedAWSRequestSigningApacheInterceptor(
+                signer.getServiceName(), signer, customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), options.getSystemIndexName()));
+        return RetryableHttpAsyncClient.builder(delegate, options);
+      }
       );
     } else if (options.getAuth().equals(FlintOptions.BASIC_AUTH)) {
       CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
@@ -303,6 +307,20 @@ public IRestHighLevelClient createClient() {
     return new RestHighLevelClientWrapper(new RestHighLevelClient(restClientBuilder));
   }
 
+  /**
+   * Attempts to instantiate the AWS credential provider using reflection.
+   */
+  private void instantiateProvider(String providerClass, AtomicReference<AWSCredentialsProvider> provider) {
+    try {
+      Class<?> awsCredentialsProviderClass = Class.forName(providerClass);
+      Constructor<?> ctor = awsCredentialsProviderClass.getDeclaredConstructor();
+      ctor.setAccessible(true);
+      provider.set((AWSCredentialsProvider) ctor.newInstance());
+    } catch (Exception e) {
+      throw new RuntimeException("Failed to instantiate AWSCredentialsProvider: " + providerClass, e);
+    }
+  }
+
   /*
   * Constructs Flint metadata with latest metadata log entry attached if it's available.
   * It relies on FlintOptions to provide data source name.
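instantiateProvider() centralizes the try/catch reflection that previously lived inline in createClient(). The same pattern in isolation, with a stand-in CredentialsProvider interface instead of the real AWS SDK type:

```java
import java.lang.reflect.Constructor;
import java.util.concurrent.atomic.AtomicReference;

public class ReflectiveProviderDemo {
    interface CredentialsProvider {}                              // stand-in for AWSCredentialsProvider
    public static class DemoProvider implements CredentialsProvider {}

    // Mirrors instantiateProvider(): load the class by name, call its no-arg constructor
    static void instantiateProvider(String providerClass, AtomicReference<CredentialsProvider> provider) {
        try {
            Class<?> clazz = Class.forName(providerClass);
            Constructor<?> ctor = clazz.getDeclaredConstructor();
            ctor.setAccessible(true);
            provider.set((CredentialsProvider) ctor.newInstance());
        } catch (Exception e) {
            throw new RuntimeException("Failed to instantiate provider: " + providerClass, e);
        }
    }

    public static void main(String[] args) {
        AtomicReference<CredentialsProvider> provider = new AtomicReference<>();
        instantiateProvider("ReflectiveProviderDemo$DemoProvider", provider);
        System.out.println(provider.get().getClass().getSimpleName()); // DemoProvider
    }
}
```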
diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptorTest.java b/flint-core/src/test/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptorTest.java
new file mode 100644
index 000000000..0ef021b53
--- /dev/null
+++ b/flint-core/src/test/scala/org/opensearch/flint/core/auth/ResourceBasedAWSRequestSigningApacheInterceptorTest.java
@@ -0,0 +1,66 @@
+package org.opensearch.flint.core.auth;
+
+import com.amazonaws.auth.AWSCredentialsProvider;
+import com.amazonaws.auth.Signer;
+import org.apache.http.HttpRequest;
+import org.apache.http.message.BasicHttpRequest;
+import org.apache.http.protocol.HttpContext;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import static org.mockito.Mockito.*;
+
+public class ResourceBasedAWSRequestSigningApacheInterceptorTest {
+
+  @Mock
+  private Signer mockSigner;
+  @Mock
+  private AWSCredentialsProvider mockPrimaryCredentialsProvider;
+  @Mock
+  private AWSCredentialsProvider mockMetadataAccessCredentialsProvider;
+  @Mock
+  private HttpContext mockContext;
+  @Captor
+  private ArgumentCaptor<HttpRequest> httpRequestCaptor;
+
+  private ResourceBasedAWSRequestSigningApacheInterceptor interceptor;
+
+  @Before
+  public void setUp() {
+    MockitoAnnotations.initMocks(this);
+    AWSRequestSigningApacheInterceptor primaryInterceptorSpy = spy(new AWSRequestSigningApacheInterceptor("es", mockSigner, mockPrimaryCredentialsProvider));
+    AWSRequestSigningApacheInterceptor metadataInterceptorSpy = spy(new AWSRequestSigningApacheInterceptor("es", mockSigner, mockMetadataAccessCredentialsProvider));
+
+    interceptor = new ResourceBasedAWSRequestSigningApacheInterceptor(
+        "es",
+        primaryInterceptorSpy,
+        metadataInterceptorSpy,
+        "/metadata");
+  }
+
+  @Test
+  public void testProcessWithMetadataAccess() throws Exception {
+    HttpRequest request = new BasicHttpRequest("GET", "/es/metadata/resource");
+
+    interceptor.process(request, mockContext);
+
+    verify(interceptor.metadataAccessInterceptor).process(httpRequestCaptor.capture(), eq(mockContext));
+    verify(interceptor.primaryInterceptor, never()).process(any(HttpRequest.class), any(HttpContext.class));
+    assert httpRequestCaptor.getValue().getRequestLine().getUri().contains("/metadata");
+  }
+
+  @Test
+  public void testProcessWithoutMetadataAccess() throws Exception {
+    HttpRequest request = new BasicHttpRequest("GET", "/es/regular/resource");
+
+    interceptor.process(request, mockContext);
+
+    verify(interceptor.primaryInterceptor).process(httpRequestCaptor.capture(), eq(mockContext));
+    verify(interceptor.metadataAccessInterceptor, never()).process(any(HttpRequest.class), any(HttpContext.class));
+    assert !httpRequestCaptor.getValue().getRequestLine().getUri().contains("/metadata");
+  }
+}
diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala
index fbbea9176..eb3a29adc 100644
--- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala
+++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala
@@ -75,7 +75,7 @@ object FlintSparkConf {
     FlintConfig("spark.datasource.flint.customAWSCredentialsProvider")
       .datasourceOption()
       .doc("AWS customAWSCredentialsProvider")
-      .createWithDefault(FlintOptions.DEFAULT_CUSTOM_AWS_CREDENTIALS_PROVIDER)
+      .createWithDefault(FlintOptions.DEFAULT_AWS_CREDENTIALS_PROVIDER)
 
   val DOC_ID_COLUMN_NAME = FlintConfig("spark.datasource.flint.write.id_name")
     .datasourceOption()
@@ -174,6 +174,10 @@ object FlintSparkConf {
   FlintConfig(s"spark.flint.job.inactivityLimitMillis")
     .doc("inactivity timeout")
     .createWithDefault(String.valueOf(FlintOptions.DEFAULT_INACTIVITY_LIMIT_MILLIS))
+  val METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER =
+    FlintConfig("spark.metadata.accessAWSCredentialsProvider")
+      .doc("AWS credentials provider for metadata access permission")
+      .createOptional()
 }
 
 /**
@@ -234,6 +238,7 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable
       DATA_SOURCE_NAME,
       SESSION_ID,
       REQUEST_INDEX,
+      METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER,
       EXCLUDE_JOB_IDS)
       .map(conf => (conf.optionKey, conf.readFrom(reader)))
       .flatMap {
diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala
index 149e8128b..3d643dde3 100644
--- a/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala
@@ -9,6 +9,7 @@ import java.util.Optional
 
 import scala.collection.JavaConverters._
 
+import org.opensearch.flint.core.FlintOptions
 import org.opensearch.flint.core.http.FlintRetryOptions._
 import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
 
@@ -62,6 +63,18 @@ class FlintSparkConfSuite extends FlintSuite {
     retryOptions.getRetryableExceptionClassNames.get() shouldBe "java.net.ConnectException"
   }
 
+  test("test metadata access AWS credentials provider option") {
+    withSparkConf("spark.metadata.accessAWSCredentialsProvider") {
+      spark.conf.set(
+        "spark.metadata.accessAWSCredentialsProvider",
+        "com.example.MetadataAccessCredentialsProvider")
+      val flintOptions = FlintSparkConf().flintOptions()
+      assert(flintOptions.getCustomAwsCredentialsProvider == "")
+      assert(
+        flintOptions.getMetadataAccessAwsCredentialsProvider == "com.example.MetadataAccessCredentialsProvider")
+    }
+  }
+
   /**
    * Delete index `indexNames` after calling `f`.
    */
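Since the new key is registered with createOptional(), it reaches FlintOptions only when a caller sets it. A quick sketch of setting and reading the key with a plain SparkConf, assuming spark-core on the classpath; the provider class name is the same placeholder used in the test above:

```java
import org.apache.spark.SparkConf;

public class MetadataProviderConfDemo {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf()
            .set("spark.metadata.accessAWSCredentialsProvider",
                 "com.example.MetadataAccessCredentialsProvider");

        // createOptional() means no built-in default: an absent key falls back to the
        // second argument here, matching FlintOptions' empty-string default
        System.out.println(conf.get("spark.metadata.accessAWSCredentialsProvider", ""));
        System.out.println(conf.get("spark.datasource.flint.customAWSCredentialsProvider", ""));
    }
}
```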
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
index 8b4bdeeaf..0ac683f7b 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
@@ -56,7 +56,7 @@ object FlintJob extends Logging with FlintJobExecutor {
     conf.set(FlintSparkConf.JOB_TYPE.key, jobType)
 
     val dataSource = conf.get("spark.flint.datasource.name", "")
-    val query = queryOption.getOrElse(conf.get(FlintSparkConf.QUERY.key, ""))
+    val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, "")))
     if (query.isEmpty) {
       throw new IllegalArgumentException(s"Query undefined for the ${jobType} job.")
     }
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
index ccd5c8f3f..9cd31208f 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
@@ -11,6 +11,7 @@ import scala.concurrent.{ExecutionContext, Future, TimeoutException}
 import scala.concurrent.duration.{Duration, MINUTES}
 
 import com.amazonaws.services.s3.model.AmazonS3Exception
+import org.apache.commons.text.StringEscapeUtils.unescapeJava
 import org.opensearch.flint.core.{FlintClient, IRestHighLevelClient}
 import org.opensearch.flint.core.metadata.FlintMetadata
 import org.opensearch.flint.core.metrics.MetricConstants
@@ -361,6 +362,14 @@ trait FlintJobExecutor {
     }
   }
 
+  /**
+   * Unescape the query string, which is escaped for EMR spark-submit parameter parsing. Ref:
+   * https://github.com/opensearch-project/sql/pull/2587
+   */
+  def unescapeQuery(query: String): String = {
+    unescapeJava(query)
+  }
+
   def executeQuery(
       spark: SparkSession,
       query: String,
@@ -371,6 +380,7 @@ trait FlintJobExecutor {
     val startTime = System.currentTimeMillis()
     // we have to set job group in the same thread that started the query according to spark doc
     spark.sparkContext.setJobGroup(queryId, "Job group for " + queryId, interruptOnCancel = true)
+    logInfo(s"Executing query: $query")
     val result: DataFrame = spark.sql(query)
     // Get Data
     getFormattedData(
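unescapeQuery() is a thin wrapper over Commons Text's unescapeJava. A minimal sketch of the round-trip it compensates for, assuming org.apache.commons:commons-text on the classpath:

```java
import org.apache.commons.text.StringEscapeUtils;

public class UnescapeQueryDemo {
    public static void main(String[] args) {
        // As delivered through EMR spark-submit, double quotes arrive escaped: SELECT \"1\"
        String escaped = "SELECT \\\"1\\\"";
        // unescapeJava() restores the original SQL: SELECT "1"
        System.out.println(StringEscapeUtils.unescapeJava(escaped));
    }
}
```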
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
index 76e5f692c..69b655e57 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
@@ -251,7 +251,7 @@
         if (defaultQuery.isEmpty) {
           throw new IllegalArgumentException("Query undefined for the streaming job.")
         }
-        defaultQuery
+        unescapeQuery(defaultQuery)
       } else ""
     }
   }
"streaming" + val conf = new SparkConf().set( + FlintSparkConf.QUERY.key, + "SELECT \\\"1\\\" UNION SELECT '\\\"1\\\"' UNION SELECT \\\"\\\\\\\"1\\\\\\\"\\\"") + + val query = FlintREPL.getQuery(queryOption, jobType, conf) + query shouldBe "SELECT \"1\" UNION SELECT '\"1\"' UNION SELECT \"\\\"1\\\"\"" + } + test( "getQuery should throw IllegalArgumentException if queryOption is None and default query is not defined for streaming job") { val queryOption = None