Use refresh policy from config
Signed-off-by: Tomoyuki Morita <[email protected]>
ykmr1224 committed Aug 7, 2024
1 parent 5be9be6 commit a08da93
Showing 8 changed files with 116 additions and 56 deletions.
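In short: writes against OpenSearch previously hardcoded `RefreshPolicy.WAIT_UNTIL`; they now read the policy from `FlintOptions`, so a job can override it, e.g. with `spark.datasource.flint.write.refresh_policy=false` for OpenSearch Serverless (AOSS), which does not support `wait_for`-style refresh control. A minimal sketch of the intended write path follows — not verbatim repo code; `resultData`/`resultIndex` are stand-ins, and the datasource option key behind `REFRESH_POLICY.optionKey` is assumed to be `refresh_policy`:

```scala
import org.apache.spark.sql.DataFrame

// Sketch of the write path after this change (assumptions noted above).
// OpenSearch accepts three refresh-policy strings:
//   "true"     -> refresh immediately
//   "wait_for" -> block until a refresh makes the write searchable (the old hardcoded behavior)
//   "false"    -> return without waiting (what the new AOSS test configures)
def writeResults(resultData: DataFrame, resultIndex: String, refreshPolicy: String): Unit =
  resultData.write
    .format("flint")
    .option("refresh_policy", refreshPolicy) // assumed key; the code uses REFRESH_POLICY.optionKey
    .mode("append")
    .save(resultIndex)
```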
1 change: 1 addition & 0 deletions DEVELOPER_GUIDE.md
@@ -27,6 +27,7 @@ If you get integration test failures with error message "Previous attempts to fi
The `aws-integration` folder contains tests for cloud service providers. For instance, to test against an AWS OpenSearch domain, configure the following settings. The client will use the default credential provider chain to access the AWS OpenSearch domain.
```
export AWS_OPENSEARCH_HOST=search-xxx.us-west-2.on.aws
+ export AWS_OPENSEARCH_SERVERLESS_HOST=xxx.us-west-2.aoss.amazonaws.com
export AWS_REGION=us-west-2
export AWS_EMRS_APPID=xxx
export AWS_EMRS_EXECUTION_ROLE=xxx
@@ -155,7 +155,7 @@ private FlintMetadataLogEntry createLogEntry(FlintMetadataLogEntry logEntry) {
new IndexRequest()
.index(metadataLogIndexName)
.id(logEntryWithId.id())
- .setRefreshPolicy(RefreshPolicy.WAIT_UNTIL)
+ .setRefreshPolicy(options.getRefreshPolicy())
.source(toJson(logEntryWithId), XContentType.JSON),
RequestOptions.DEFAULT));
}
@@ -166,7 +166,7 @@ private FlintMetadataLogEntry updateLogEntry(FlintMetadataLogEntry logEntry) {
client -> client.update(
new UpdateRequest(metadataLogIndexName, logEntry.id())
.doc(toJson(logEntry), XContentType.JSON)
- .setRefreshPolicy(RefreshPolicy.WAIT_UNTIL)
+ .setRefreshPolicy(options.getRefreshPolicy())
.setIfSeqNo((Long) logEntry.entryVersion().get("seqNo").get())
.setIfPrimaryTerm((Long) logEntry.entryVersion().get("primaryTerm").get()),
RequestOptions.DEFAULT));
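Both the create and update paths of the metadata log now defer to `options.getRefreshPolicy()`. Judging by `FlintJobExecutor` later in this commit, the getter returns the policy as a plain string, which the client's `setRefreshPolicy(String)` overload parses. A rough sketch of such a getter — only the method name is confirmed by this diff; the backing key and default are guesses:

```scala
// Assumed sketch of the FlintOptions getter called above; key name and
// default value are assumptions, not taken from this commit.
class FlintOptionsSketch(options: java.util.Map[String, String]) {
  def getRefreshPolicy: String =
    options.getOrDefault("refresh_policy", "wait_for") // "wait_for" would preserve the old behavior
}
```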
@@ -6,6 +6,7 @@
import org.opensearch.client.indices.GetIndexRequest;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.flint.core.FlintClient;
+ import org.opensearch.flint.core.FlintOptions;
import org.opensearch.flint.core.IRestHighLevelClient;

import java.io.IOException;
@@ -25,10 +25,12 @@ public class OpenSearchUpdater {

private final String indexName;
private final FlintClient flintClient;
+ private final FlintOptions options;

- public OpenSearchUpdater(String indexName, FlintClient flintClient) {
+ public OpenSearchUpdater(String indexName, FlintClient flintClient, FlintOptions options) {
this.indexName = indexName;
this.flintClient = flintClient;
+ this.options = options;
}

public void upsert(String id, String doc) {
@@ -61,7 +64,7 @@ private void updateDocument(String id, String doc, boolean upsert, long seqNo, l
assertIndexExist(client, indexName);
UpdateRequest updateRequest = new UpdateRequest(indexName, id)
.doc(doc, XContentType.JSON)
- .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL);
+ .setRefreshPolicy(options.getRefreshPolicy());

if (upsert) {
updateRequest.docAsUpsert(true);
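`OpenSearchUpdater` now takes the options at construction time instead of assuming `WAIT_UNTIL`, so call sites must thread `FlintOptions` through. A sketch of the updated construction pattern, mirroring the test setup later in this commit (`requestIndex` and `openSearchOptions` stand in for the surrounding test fixtures):

```scala
// One FlintOptions instance backs both the client and the updater,
// so both sides agree on the refresh policy.
val options = new FlintOptions(openSearchOptions.asJava)
val updater = new OpenSearchUpdater(
  requestIndex,
  new FlintOpenSearchClient(options),
  options)
```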
@@ -5,26 +5,28 @@

package org.opensearch.flint.spark.aws

+ import java.io.File
import java.time.LocalDateTime

import scala.concurrent.duration.DurationInt

- import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder
+ import com.amazonaws.services.emrserverless.{AWSEMRServerless, AWSEMRServerlessClientBuilder}
import com.amazonaws.services.emrserverless.model.{GetJobRunRequest, JobDriver, SparkSubmit, StartJobRunRequest}
- import com.amazonaws.services.s3.AmazonS3ClientBuilder
+ import com.amazonaws.services.s3.{AmazonS3, AmazonS3ClientBuilder}
import org.scalatest.BeforeAndAfter
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

import org.apache.spark.internal.Logging

class AWSEmrServerlessAccessTestSuite
- extends AnyFlatSpec
+ extends AnyFlatSpec
with BeforeAndAfter
with Matchers
with Logging {

lazy val testHost: String = System.getenv("AWS_OPENSEARCH_HOST")
+ lazy val testServerlessHost: String = System.getenv("AWS_OPENSEARCH_SERVERLESS_HOST")
lazy val testPort: Int = -1
lazy val testRegion: String = System.getenv("AWS_REGION")
lazy val testScheme: String = "https"
@@ -36,53 +38,38 @@ class AWSEmrServerlessAccessTestSuite
lazy val testS3CodePrefix: String = System.getenv("AWS_S3_CODE_PREFIX")
lazy val testResultIndex: String = System.getenv("AWS_OPENSEARCH_RESULT_INDEX")

"EMR Serverless job" should "run successfully" in {
"EMR Serverless job with AOS" should "run successfully" in {
val s3Client = AmazonS3ClientBuilder.standard().withRegion(testRegion).build()
val emrServerless = AWSEMRServerlessClientBuilder.standard().withRegion(testRegion).build()

- val appJarPath =
- sys.props.getOrElse("appJar", throw new IllegalArgumentException("appJar not set"))
- val extensionJarPath = sys.props.getOrElse(
- "extensionJar",
- throw new IllegalArgumentException("extensionJar not set"))
- val pplJarPath =
- sys.props.getOrElse("pplJar", throw new IllegalArgumentException("pplJar not set"))
+ uploadJarsToS3(s3Client)

- s3Client.putObject(
- testS3CodeBucket,
- s"$testS3CodePrefix/sql-job.jar",
- new java.io.File(appJarPath))
- s3Client.putObject(
- testS3CodeBucket,
- s"$testS3CodePrefix/extension.jar",
- new java.io.File(extensionJarPath))
- s3Client.putObject(
- testS3CodeBucket,
- s"$testS3CodePrefix/ppl.jar",
- new java.io.File(pplJarPath))
+ val jobRunRequest = startJobRun("SELECT 1", testHost, "es")

- val jobRunRequest = new StartJobRunRequest()
- .withApplicationId(testAppId)
- .withExecutionRoleArn(testExecutionRole)
- .withName(s"integration-${LocalDateTime.now()}")
- .withJobDriver(new JobDriver()
- .withSparkSubmit(new SparkSubmit()
- .withEntryPoint(s"s3://$testS3CodeBucket/$testS3CodePrefix/sql-job.jar")
- .withEntryPointArguments(testResultIndex)
- .withSparkSubmitParameters(s"--class org.apache.spark.sql.FlintJob --jars " +
- s"s3://$testS3CodeBucket/$testS3CodePrefix/extension.jar," +
- s"s3://$testS3CodeBucket/$testS3CodePrefix/ppl.jar " +
- s"--conf spark.datasource.flint.host=$testHost " +
- s"--conf spark.datasource.flint.port=-1 " +
- s"--conf spark.datasource.flint.scheme=$testScheme " +
- s"--conf spark.datasource.flint.auth=$testAuth " +
- s"--conf spark.sql.catalog.glue=org.opensearch.sql.FlintDelegatingSessionCatalog " +
- s"--conf spark.flint.datasource.name=glue " +
- s"""--conf spark.flint.job.query="SELECT 1" """ +
- s"--conf spark.hadoop.hive.metastore.client.factory.class=com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory")))
val jobRunResponse = emrServerless.startJobRun(jobRunRequest)

+ verifyJobSucceed(emrServerless, jobRunResponse.getJobRunId)
+ }
+
+ "EMR Serverless job with AOSS" should "run successfully" in {
+ val s3Client = AmazonS3ClientBuilder.standard().withRegion(testRegion).build()
+ val emrServerless = AWSEMRServerlessClientBuilder.standard().withRegion(testRegion).build()
+
+ uploadJarsToS3(s3Client)
+
+ val jobRunRequest = startJobRun(
+ "SELECT 1",
+ testServerlessHost,
+ "aoss",
+ conf("spark.datasource.flint.write.refresh_policy", "false")
+ )
+
+ val jobRunResponse = emrServerless.startJobRun(jobRunRequest)
+
+ verifyJobSucceed(emrServerless, jobRunResponse.getJobRunId)
+ }
+
+ private def verifyJobSucceed(emrServerless: AWSEMRServerless, jobRunId: String): Unit = {
val startTime = System.currentTimeMillis()
val timeout = 5.minutes.toMillis
var jobState = "STARTING"
@@ -92,11 +79,72 @@
Thread.sleep(30000)
val request = new GetJobRunRequest()
.withApplicationId(testAppId)
- .withJobRunId(jobRunResponse.getJobRunId)
+ .withJobRunId(jobRunId)
jobState = emrServerless.getJobRun(request).getJobRun.getState
logInfo(s"Current job state: $jobState at ${System.currentTimeMillis()}")
}

jobState shouldBe "SUCCESS"
}

+ private def startJobRun(query: String, host: String, authServiceName: String, additionalParams: String*) = {
+ new StartJobRunRequest()
+ .withApplicationId(testAppId)
+ .withExecutionRoleArn(testExecutionRole)
+ .withName(s"integration-${authServiceName}-${LocalDateTime.now()}")
+ .withJobDriver(new JobDriver()
+ .withSparkSubmit(new SparkSubmit()
+ .withEntryPoint(s"s3://$testS3CodeBucket/$testS3CodePrefix/sql-job.jar")
+ .withEntryPointArguments(testResultIndex)
+ .withSparkSubmitParameters(
+ join(
+ clazz("org.apache.spark.sql.FlintJob"),
+ jars(s"s3://$testS3CodeBucket/$testS3CodePrefix/extension.jar", s"s3://$testS3CodeBucket/$testS3CodePrefix/ppl.jar"),
+ conf("spark.datasource.flint.host", host),
+ conf("spark.datasource.flint.port", s"$testPort"),
+ conf("spark.datasource.flint.scheme", testScheme),
+ conf("spark.datasource.flint.auth", testAuth),
+ conf("spark.datasource.flint.auth.servicename", authServiceName),
+ conf("spark.sql.catalog.glue", "org.opensearch.sql.FlintDelegatingSessionCatalog"),
+ conf("spark.flint.datasource.name", "glue"),
+ conf("spark.flint.job.query", quote(query)),
+ conf("spark.hadoop.hive.metastore.client.factory.class", "com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory"),
+ join(additionalParams: _*)
+ )
+ )
+ )
+ )
+ }
+
+ private def join(params: String*): String = params.mkString(" ")
+
+ private def clazz(clazz: String): String = s"--class $clazz"
+
+ private def jars(jars: String*): String = s"--jars ${jars.mkString(",")}"
+
+ private def quote(str: String): String = "\"" + str + "\""
+
+ private def conf(name: String, value: String): String = s"--conf $name=$value"
+
+ private def uploadJarsToS3(s3Client: AmazonS3) = {
+ val appJarPath =
+ sys.props.getOrElse("appJar", throw new IllegalArgumentException("appJar not set"))
+ val extensionJarPath = sys.props.getOrElse(
+ "extensionJar",
+ throw new IllegalArgumentException("extensionJar not set"))
+ val pplJarPath =
+ sys.props.getOrElse("pplJar", throw new IllegalArgumentException("pplJar not set"))
+
+ s3Client.putObject(
+ testS3CodeBucket,
+ s"$testS3CodePrefix/sql-job.jar",
+ new File(appJarPath))
+ s3Client.putObject(
+ testS3CodeBucket,
+ s"$testS3CodePrefix/extension.jar",
+ new File(extensionJarPath))
+ s3Client.putObject(
+ testS3CodeBucket,
+ s"$testS3CodePrefix/ppl.jar",
+ new File(pplJarPath))
+ }
}
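The old inline spark-submit string is now assembled from small helpers. Given the definitions above, the pieces render as follows — a worked example using the AOSS test's values:

```scala
conf("spark.datasource.flint.write.refresh_policy", "false")
// => --conf spark.datasource.flint.write.refresh_policy=false

join(
  clazz("org.apache.spark.sql.FlintJob"),
  conf("spark.datasource.flint.auth.servicename", "aoss"),
  conf("spark.flint.job.query", quote("SELECT 1")))
// => --class org.apache.spark.sql.FlintJob --conf spark.datasource.flint.auth.servicename=aoss --conf spark.flint.job.query="SELECT 1"
```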
@@ -118,10 +118,12 @@ class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest {

flintClient = new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava));
osClient = new OSClient(new FlintOptions(openSearchOptions.asJava))
+ val options = new FlintOptions(openSearchOptions.asJava)
updater = new OpenSearchUpdater(
requestIndex,
- new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava)))
-
+ new FlintOpenSearchClient(options),
+ options
+ )
}

override def afterEach(): Unit = {
@@ -32,9 +32,11 @@ class OpenSearchUpdaterSuite extends OpenSearchTransactionSuite with Matchers {
override def beforeAll(): Unit = {
super.beforeAll()
flintClient = new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava));
+ val options = new FlintOptions(openSearchOptions.asJava)
updater = new OpenSearchUpdater(
testMetaLogIndex,
- new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava)))
+ new FlintOpenSearchClient(options),
+ options)
}

test("upsert flintJob should success") {
@@ -129,11 +129,14 @@ trait FlintJobExecutor {
builder.getOrCreate()
}

- private def writeData(resultData: DataFrame, resultIndex: String): Unit = {
+ private def writeData(
+ resultData: DataFrame,
+ resultIndex: String,
+ refreshPolicy: String): Unit = {
try {
resultData.write
.format("flint")
.option(REFRESH_POLICY.optionKey, "wait_for")
.option(REFRESH_POLICY.optionKey, refreshPolicy)
.mode("append")
.save(resultIndex)
IRestHighLevelClient.recordOperationSuccess(
@@ -160,11 +163,12 @@
resultData: DataFrame,
resultIndex: String,
osClient: OSClient): Unit = {
+ val refreshPolicy = osClient.flintOptions.getRefreshPolicy;
if (osClient.doesIndexExist(resultIndex)) {
- writeData(resultData, resultIndex)
+ writeData(resultData, resultIndex, refreshPolicy)
} else {
createResultIndex(osClient, resultIndex, resultIndexMapping)
- writeData(resultData, resultIndex)
+ writeData(resultData, resultIndex, refreshPolicy)
}
}

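With the policy no longer pinned to `wait_for`, a result written under `refresh_policy=false` is not guaranteed to be searchable immediately, so anything reading the result index right after a write should poll for visibility. A sketch using the `OSClient.getDoc` shown below — the index and id names are stand-ins, and the retry bounds are illustrative:

```scala
// Sketch: tolerate refresh lag when the write policy is "false".
def waitForDoc(osClient: OSClient, index: String, id: String): Boolean = {
  var attempts = 0
  while (attempts < 10) {
    if (osClient.getDoc(index, id).isExists) return true
    Thread.sleep(1000) // give a background refresh time to make the doc visible
    attempts += 1
  }
  false
}
```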
@@ -111,7 +111,7 @@ class OSClient(val flintOptions: FlintOptions) extends Logging {
}

def createUpdater(indexName: String): OpenSearchUpdater =
- new OpenSearchUpdater(indexName, flintClient)
+ new OpenSearchUpdater(indexName, flintClient, flintOptions)

def getDoc(osIndexName: String, id: String): GetResponse = {
using(flintClient.createClient()) { client =>