diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md
index 1aeb89a6c..bb8f697ec 100644
--- a/DEVELOPER_GUIDE.md
+++ b/DEVELOPER_GUIDE.md
@@ -27,6 +27,7 @@ If you get integration test failures with error message "Previous attempts to fi
 The `aws-integration` folder contains tests for cloud service providers. For instance, to test against an AWS OpenSearch domain, configure the following settings. The client will use the default credential provider to access the AWS OpenSearch domain.
 ```
 export AWS_OPENSEARCH_HOST=search-xxx.us-west-2.on.aws
+export AWS_OPENSEARCH_SERVERLESS_HOST=xxx.us-west-2.aoss.amazonaws.com
 export AWS_REGION=us-west-2
 export AWS_EMRS_APPID=xxx
 export AWS_EMRS_EXECUTION_ROLE=xxx
diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java
index 7944de5ae..62dd01683 100644
--- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java
+++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchMetadataLog.java
@@ -155,7 +155,7 @@ private FlintMetadataLogEntry createLogEntry(FlintMetadataLogEntry logEntry) {
               new IndexRequest()
                   .index(metadataLogIndexName)
                   .id(logEntryWithId.id())
-                  .setRefreshPolicy(RefreshPolicy.WAIT_UNTIL)
+                  .setRefreshPolicy(options.getRefreshPolicy())
                   .source(toJson(logEntryWithId), XContentType.JSON),
               RequestOptions.DEFAULT));
   }
@@ -166,7 +166,7 @@ private FlintMetadataLogEntry updateLogEntry(FlintMetadataLogEntry logEntry) {
         client -> client.update(
             new UpdateRequest(metadataLogIndexName, logEntry.id())
                 .doc(toJson(logEntry), XContentType.JSON)
-                .setRefreshPolicy(RefreshPolicy.WAIT_UNTIL)
+                .setRefreshPolicy(options.getRefreshPolicy())
                 .setIfSeqNo((Long) logEntry.entryVersion().get("seqNo").get())
                 .setIfPrimaryTerm((Long) logEntry.entryVersion().get("primaryTerm").get()),
             RequestOptions.DEFAULT));
diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java
index 0d84b4956..d9dc54783 100644
--- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java
+++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchUpdater.java
@@ -6,6 +6,7 @@
 import org.opensearch.client.indices.GetIndexRequest;
 import org.opensearch.common.xcontent.XContentType;
 import org.opensearch.flint.core.FlintClient;
+import org.opensearch.flint.core.FlintOptions;
 import org.opensearch.flint.core.IRestHighLevelClient;
 
 import java.io.IOException;
@@ -25,10 +26,12 @@ public class OpenSearchUpdater {
 
     private final String indexName;
     private final FlintClient flintClient;
+    private final FlintOptions options;
 
-    public OpenSearchUpdater(String indexName, FlintClient flintClient) {
+    public OpenSearchUpdater(String indexName, FlintClient flintClient, FlintOptions options) {
         this.indexName = indexName;
         this.flintClient = flintClient;
+        this.options = options;
     }
 
     public void upsert(String id, String doc) {
@@ -61,7 +64,7 @@ private void updateDocument(String id, String doc, boolean upsert, long seqNo, l
         assertIndexExist(client, indexName);
         UpdateRequest updateRequest = new UpdateRequest(indexName, id)
                 .doc(doc, XContentType.JSON)
-                .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL);
+                .setRefreshPolicy(options.getRefreshPolicy());
 
         if (upsert) {
             updateRequest.docAsUpsert(true);
diff --git a/integ-test/src/aws-integration/scala/org/opensearch/flint/spark/aws/AWSEmrServerlessAccessTestSuite.scala b/integ-test/src/aws-integration/scala/org/opensearch/flint/spark/aws/AWSEmrServerlessAccessTestSuite.scala
index 67e036d28..2e599c418 100644
--- a/integ-test/src/aws-integration/scala/org/opensearch/flint/spark/aws/AWSEmrServerlessAccessTestSuite.scala
+++ b/integ-test/src/aws-integration/scala/org/opensearch/flint/spark/aws/AWSEmrServerlessAccessTestSuite.scala
@@ -5,13 +5,14 @@
 
 package org.opensearch.flint.spark.aws
 
+import java.io.File
 import java.time.LocalDateTime
 
 import scala.concurrent.duration.DurationInt
 
-import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder
+import com.amazonaws.services.emrserverless.{AWSEMRServerless, AWSEMRServerlessClientBuilder}
 import com.amazonaws.services.emrserverless.model.{GetJobRunRequest, JobDriver, SparkSubmit, StartJobRunRequest}
-import com.amazonaws.services.s3.AmazonS3ClientBuilder
+import com.amazonaws.services.s3.{AmazonS3, AmazonS3ClientBuilder}
 import org.scalatest.BeforeAndAfter
 import org.scalatest.flatspec.AnyFlatSpec
 import org.scalatest.matchers.should.Matchers
@@ -19,12 +20,13 @@ import org.scalatest.matchers.should.Matchers
 import org.apache.spark.internal.Logging
 
 class AWSEmrServerlessAccessTestSuite
-  extends AnyFlatSpec
+    extends AnyFlatSpec
     with BeforeAndAfter
     with Matchers
     with Logging {
 
   lazy val testHost: String = System.getenv("AWS_OPENSEARCH_HOST")
+  lazy val testServerlessHost: String = System.getenv("AWS_OPENSEARCH_SERVERLESS_HOST")
   lazy val testPort: Int = -1
   lazy val testRegion: String = System.getenv("AWS_REGION")
   lazy val testScheme: String = "https"
@@ -36,53 +38,38 @@ class AWSEmrServerlessAccessTestSuite
   lazy val testS3CodePrefix: String = System.getenv("AWS_S3_CODE_PREFIX")
   lazy val testResultIndex: String = System.getenv("AWS_OPENSEARCH_RESULT_INDEX")
 
-  "EMR Serverless job" should "run successfully" in {
+  "EMR Serverless job with AOS" should "run successfully" in {
     val s3Client = AmazonS3ClientBuilder.standard().withRegion(testRegion).build()
     val emrServerless = AWSEMRServerlessClientBuilder.standard().withRegion(testRegion).build()
 
-    val appJarPath =
-      sys.props.getOrElse("appJar", throw new IllegalArgumentException("appJar not set"))
-    val extensionJarPath = sys.props.getOrElse(
-      "extensionJar",
-      throw new IllegalArgumentException("extensionJar not set"))
-    val pplJarPath =
-      sys.props.getOrElse("pplJar", throw new IllegalArgumentException("pplJar not set"))
+    uploadJarsToS3(s3Client)
 
-    s3Client.putObject(
-      testS3CodeBucket,
-      s"$testS3CodePrefix/sql-job.jar",
-      new java.io.File(appJarPath))
-    s3Client.putObject(
-      testS3CodeBucket,
-      s"$testS3CodePrefix/extension.jar",
-      new java.io.File(extensionJarPath))
-    s3Client.putObject(
-      testS3CodeBucket,
-      s"$testS3CodePrefix/ppl.jar",
-      new java.io.File(pplJarPath))
+    val jobRunRequest = startJobRun("SELECT 1", testHost, "es")
 
-    val jobRunRequest = new StartJobRunRequest()
-      .withApplicationId(testAppId)
-      .withExecutionRoleArn(testExecutionRole)
-      .withName(s"integration-${LocalDateTime.now()}")
-      .withJobDriver(new JobDriver()
-        .withSparkSubmit(new SparkSubmit()
-          .withEntryPoint(s"s3://$testS3CodeBucket/$testS3CodePrefix/sql-job.jar")
-          .withEntryPointArguments(testResultIndex)
-          .withSparkSubmitParameters(s"--class org.apache.spark.sql.FlintJob --jars " +
-            s"s3://$testS3CodeBucket/$testS3CodePrefix/extension.jar," +
-            s"s3://$testS3CodeBucket/$testS3CodePrefix/ppl.jar " +
-            s"--conf spark.datasource.flint.host=$testHost " +
-            s"--conf spark.datasource.flint.port=-1 " +
-            s"--conf spark.datasource.flint.scheme=$testScheme " +
-            s"--conf spark.datasource.flint.auth=$testAuth " +
-            s"--conf spark.sql.catalog.glue=org.opensearch.sql.FlintDelegatingSessionCatalog " +
-            s"--conf spark.flint.datasource.name=glue " +
-            s"""--conf spark.flint.job.query="SELECT 1" """ +
-            s"--conf spark.hadoop.hive.metastore.client.factory.class=com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory")))
+    val jobRunResponse = emrServerless.startJobRun(jobRunRequest)
+
+    verifyJobSucceed(emrServerless, jobRunResponse.getJobRunId)
+  }
+
+  "EMR Serverless job with AOSS" should "run successfully" in {
+    val s3Client = AmazonS3ClientBuilder.standard().withRegion(testRegion).build()
+    val emrServerless = AWSEMRServerlessClientBuilder.standard().withRegion(testRegion).build()
+
+    uploadJarsToS3(s3Client)
+
+    val jobRunRequest = startJobRun(
+      "SELECT 1",
+      testServerlessHost,
+      "aoss",
+      conf("spark.datasource.flint.write.refresh_policy", "false")
+    )
 
     val jobRunResponse = emrServerless.startJobRun(jobRunRequest)
 
+    verifyJobSucceed(emrServerless, jobRunResponse.getJobRunId)
+  }
+
+  private def verifyJobSucceed(emrServerless: AWSEMRServerless, jobRunId: String): Unit = {
     val startTime = System.currentTimeMillis()
     val timeout = 5.minutes.toMillis
     var jobState = "STARTING"
@@ -92,11 +79,72 @@ class AWSEmrServerlessAccessTestSuite
       Thread.sleep(30000)
       val request = new GetJobRunRequest()
         .withApplicationId(testAppId)
-        .withJobRunId(jobRunResponse.getJobRunId)
+        .withJobRunId(jobRunId)
       jobState = emrServerless.getJobRun(request).getJobRun.getState
       logInfo(s"Current job state: $jobState at ${System.currentTimeMillis()}")
     }
-
     jobState shouldBe "SUCCESS"
   }
+
+  private def startJobRun(query: String, host: String, authServiceName: String, additionalParams: String*) = {
+    new StartJobRunRequest()
+      .withApplicationId(testAppId)
+      .withExecutionRoleArn(testExecutionRole)
+      .withName(s"integration-${authServiceName}-${LocalDateTime.now()}")
+      .withJobDriver(new JobDriver()
+        .withSparkSubmit(new SparkSubmit()
+          .withEntryPoint(s"s3://$testS3CodeBucket/$testS3CodePrefix/sql-job.jar")
+          .withEntryPointArguments(testResultIndex)
+          .withSparkSubmitParameters(
+            join(
+              clazz("org.apache.spark.sql.FlintJob"),
+              jars(s"s3://$testS3CodeBucket/$testS3CodePrefix/extension.jar", s"s3://$testS3CodeBucket/$testS3CodePrefix/ppl.jar"),
+              conf("spark.datasource.flint.host", host),
+              conf("spark.datasource.flint.port", s"$testPort"),
+              conf("spark.datasource.flint.scheme", testScheme),
+              conf("spark.datasource.flint.auth", testAuth),
+              conf("spark.datasource.flint.auth.servicename", authServiceName),
+              conf("spark.sql.catalog.glue", "org.opensearch.sql.FlintDelegatingSessionCatalog"),
+              conf("spark.flint.datasource.name", "glue"),
+              conf("spark.flint.job.query", quote(query)),
+              conf("spark.hadoop.hive.metastore.client.factory.class", "com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory"),
+              join(additionalParams: _*)
+            )
+          )
+        )
+      )
+  }
+
+  private def join(params: String*): String = params.mkString(" ")
+
+  private def clazz(clazz: String): String = s"--class $clazz"
+
+  private def jars(jars: String*): String = s"--jars ${jars.mkString(",")}"
+
+  private def quote(str: String): String = "\"" + str + "\""
+
+  private def conf(name: String, value: String): String = s"--conf $name=$value"
+
+  private def uploadJarsToS3(s3Client: AmazonS3) = {
+    val appJarPath =
+      sys.props.getOrElse("appJar", throw new IllegalArgumentException("appJar not set"))
+    val extensionJarPath = sys.props.getOrElse(
+      "extensionJar",
+      throw new IllegalArgumentException("extensionJar not set"))
+    val pplJarPath =
+      sys.props.getOrElse("pplJar", throw new IllegalArgumentException("pplJar not set"))
+
+    s3Client.putObject(
+      testS3CodeBucket,
+      s"$testS3CodePrefix/sql-job.jar",
+      new File(appJarPath))
+    s3Client.putObject(
+      testS3CodeBucket,
+      s"$testS3CodePrefix/extension.jar",
+      new File(extensionJarPath))
+    s3Client.putObject(
+      testS3CodeBucket,
+      s"$testS3CodePrefix/ppl.jar",
+      new File(pplJarPath))
+  }
 }
diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala
index d2a43a877..5c101ac2d 100644
--- a/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala
+++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintREPLITSuite.scala
@@ -118,10 +118,8 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {
 
     flintClient = new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava));
     osClient = new OSClient(new FlintOptions(openSearchOptions.asJava))
-    updater = new OpenSearchUpdater(
-      requestIndex,
-      new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava)))
-
+    val options = new FlintOptions(openSearchOptions.asJava)
+    updater = new OpenSearchUpdater(requestIndex, new FlintOpenSearchClient(options), options)
   }
 
   override def afterEach(): Unit = {
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala
index fa7f75b81..e6198496b 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/core/OpenSearchUpdaterSuite.scala
@@ -32,9 +32,8 @@ class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers {
   override def beforeAll(): Unit = {
     super.beforeAll()
     flintClient = new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava));
-    updater = new OpenSearchUpdater(
-      testMetaLogIndex,
-      new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava)))
+    val options = new FlintOptions(openSearchOptions.asJava)
+    updater = new OpenSearchUpdater(testMetaLogIndex, new FlintOpenSearchClient(options), options)
   }
 
   test("upsert flintJob should success") {
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
index 00f023694..f9cccf27a 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
@@ -129,11 +129,14 @@
     builder.getOrCreate()
   }
 
-  private def writeData(resultData: DataFrame, resultIndex: String): Unit = {
+  private def writeData(
+      resultData: DataFrame,
+      resultIndex: String,
+      refreshPolicy: String): Unit = {
     try {
       resultData.write
         .format("flint")
-        .option(REFRESH_POLICY.optionKey, "wait_for")
+        .option(REFRESH_POLICY.optionKey, refreshPolicy)
         .mode("append")
         .save(resultIndex)
       IRestHighLevelClient.recordOperationSuccess(
@@ -160,11 +163,12 @@
       resultData: DataFrame,
       resultIndex: String,
       osClient: OSClient): Unit = {
+    val refreshPolicy = osClient.flintOptions.getRefreshPolicy
     if (osClient.doesIndexExist(resultIndex)) {
-      writeData(resultData, resultIndex)
+      writeData(resultData, resultIndex, refreshPolicy)
     } else {
       createResultIndex(osClient, resultIndex, resultIndexMapping)
-      writeData(resultData, resultIndex)
+      writeData(resultData, resultIndex, refreshPolicy)
     }
   }
 
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala
index 422cfc947..ebac04876 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala
@@ -111,7 +111,7 @@ class OSClient(val flintOptions: FlintOptions) extends Logging {
   }
 
   def createUpdater(indexName: String): OpenSearchUpdater =
-    new OpenSearchUpdater(indexName, flintClient)
+    new OpenSearchUpdater(indexName, flintClient, flintOptions)
 
   def getDoc(osIndexName: String, id: String): GetResponse = {
     using(flintClient.createClient()) { client =>
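
Reviewer note: the hunks above call an `options.getRefreshPolicy()` / `flintOptions.getRefreshPolicy` accessor on `FlintOptions` that is not itself part of this diff. As a rough sketch of the shape such an accessor plausibly takes; the constant names, the stored option key, and the default value below are assumptions, not taken from this diff (only the Spark-side key `spark.datasource.flint.write.refresh_policy`, passed by the AOSS test above, is confirmed here). The OpenSearch client's `WriteRequest.setRefreshPolicy(String)` overload accepts `"true"`, `"false"`, or `"wait_for"`, so a plain `String`-valued option is sufficient:

```
import java.io.Serializable;
import java.util.Map;

// Hypothetical sketch of the FlintOptions accessor this diff assumes.
public class FlintOptions implements Serializable {

  private final Map<String, String> options;

  // Assumed names and default: "wait_for" preserves the previous
  // WAIT_UNTIL behavior on regular OpenSearch domains, while jobs
  // targeting OpenSearch Serverless (AOSS) override it with "false",
  // since AOSS does not support blocking a write on a refresh.
  public static final String REFRESH_POLICY = "write.refresh_policy";
  public static final String DEFAULT_REFRESH_POLICY = "wait_for";

  public FlintOptions(Map<String, String> options) {
    this.options = options;
  }

  public String getRefreshPolicy() {
    return options.getOrDefault(REFRESH_POLICY, DEFAULT_REFRESH_POLICY);
  }
}
```

This is also why the AOSS integration test passes `conf("spark.datasource.flint.write.refresh_policy", "false")` explicitly: a `wait_for` write is rejected on Serverless collections, whereas `"false"` returns as soon as the request is acknowledged.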