From b8c79290e7abbc830aea8fed715bc5d8c5db2c6a Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Mon, 22 Jan 2024 15:48:46 -0800 Subject: [PATCH] Fix Session state bug and improve Query Efficiency in REPL This PR introduces a bug fix and enhancements to FlintREPL's session management and optimizes query execution methods. It addresses a specific issue where marking a session as 'dead' inadvertently triggered the creation of a new, unnecessary session. This behavior resulted in the new session entering a spin-wait state, leading to duplicate jobs. The improvements include: - **Introduction of `earlyExitFlag`**: A new flag, `earlyExitFlag`, has been introduced and is set to `true` under two conditions: when a job is excluded or when it is not associated with the current session's job run ID. This flag is evaluated in the shutdown hook to determine whether the session state should be marked as 'dead'. This change effectively prevents the unintended creation of duplicate sessions by the SQL plugin, ensuring resources are utilized more efficiently. - **Query Method Optimization**: The method for executing queries has been shifted from scrolling to search, eliminating the need for creating unnecessary scroll contexts. This adjustment enhances the performance and efficiency of query operations. - **Reversion of Previous Commit**: The PR reverts a previous change (https://github.com/opensearch-project/opensearch-spark/commit/be8202480be851f2aba6d13a2edb12a379b0cd58) following the resolution of the related issue in the SQL plugin (https://github.com/opensearch-project/sql/issues/2436), further streamlining the operation and maintenance of the system. **Testing**: 1. Integration tests were added to cover both REPL and streaming job functionalities, ensuring the robustness of the fixes. 2. Manual testing was conducted to validate the bug fix. Signed-off-by: Kaituo Li --- build.sbt | 9 +- .../opensearch/flint/core/FlintOptions.java | 2 + .../core/storage/OpenSearchQueryReader.java | 48 ++ .../flint/core/storage/OpenSearchReader.java | 8 + .../sql/flint/config/FlintSparkConf.scala | 36 +- .../opensearch/flint/app/FlintInstance.scala | 9 +- .../apache/spark/sql/FlintJobITSuite.scala | 259 ++++++++ .../apache/spark/sql/FlintREPLITSuite.scala | 573 ++++++++++++++++++ .../scala/org/apache/spark/sql/JobTest.scala | 88 +++ .../org/apache/spark/sql/REPLResult.scala | 53 ++ .../flint/spark/FlintSparkSuite.scala | 1 + .../ppl/FlintSparkPPLCorrelationITSuite.scala | 3 +- .../ppl/FlintSparkPPLFiltersITSuite.scala | 1 + .../scala/org/apache/spark/sql/FlintJob.scala | 8 +- .../apache/spark/sql/FlintJobExecutor.scala | 22 +- .../org/apache/spark/sql/FlintREPL.scala | 56 +- .../org/apache/spark/sql/JobOperator.scala | 3 +- .../scala/org/apache/spark/sql/OSClient.scala | 29 +- .../spark/sql/util/EnvironmentProvider.scala | 10 + .../spark/sql/util/RealEnvironment.scala | 10 + .../org/apache/spark/sql/FlintREPLTest.scala | 25 +- .../spark/sql/util/MockEnvironment.scala | 10 + 22 files changed, 1212 insertions(+), 51 deletions(-) create mode 100644 flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchQueryReader.java create mode 100644 integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala create mode 100644 integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala create mode 100644 integ-test/src/test/scala/org/apache/spark/sql/JobTest.scala create mode 100644 integ-test/src/test/scala/org/apache/spark/sql/REPLResult.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/util/EnvironmentProvider.scala create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/util/RealEnvironment.scala create mode 100644 spark-sql-application/src/test/scala/org/apache/spark/sql/util/MockEnvironment.scala diff --git a/build.sbt b/build.sbt index 48e4bca5b..3c6281684 100644 --- a/build.sbt +++ b/build.sbt @@ -42,8 +42,9 @@ lazy val commonSettings = Seq( testScalastyle := (Test / scalastyle).toTask("").value, Test / test := ((Test / test) dependsOn testScalastyle).value) +// running `scalafmtAll` includes all subprojects under root lazy val root = (project in file(".")) - .aggregate(flintCore, flintSparkIntegration, pplSparkIntegration, sparkSqlApplication) + .aggregate(flintCore, flintSparkIntegration, pplSparkIntegration, sparkSqlApplication, integtest) .disablePlugins(AssemblyPlugin) .settings(name := "flint", publish / skip := true) @@ -159,7 +160,7 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration")) // Test assembly package with integration test. lazy val integtest = (project in file("integ-test")) - .dependsOn(flintSparkIntegration % "test->test", pplSparkIntegration % "test->test" ) + .dependsOn(flintSparkIntegration % "test->test", pplSparkIntegration % "test->test", sparkSqlApplication % "test->test") .settings( commonSettings, name := "integ-test", @@ -175,7 +176,9 @@ lazy val integtest = (project in file("integ-test")) "org.opensearch.client" % "opensearch-java" % "2.6.0" % "test" exclude ("com.fasterxml.jackson.core", "jackson-databind")), libraryDependencies ++= deps(sparkVersion), - Test / fullClasspath ++= Seq((flintSparkIntegration / assembly).value, (pplSparkIntegration / assembly).value)) + Test / fullClasspath ++= Seq((flintSparkIntegration / assembly).value, (pplSparkIntegration / assembly).value, + (sparkSqlApplication / assembly).value + )) lazy val standaloneCosmetic = project .settings( diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java index c1c5491ed..410d896d2 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java @@ -77,6 +77,8 @@ public class FlintOptions implements Serializable { public static final int DEFAULT_SOCKET_TIMEOUT_MILLIS = 60000; + public static final int DEFAULT_INACTIVITY_LIMIT_MILLIS = 10 * 60 * 1000; + public FlintOptions(Map options) { this.options = options; this.retryOptions = new FlintRetryOptions(options); diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchQueryReader.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchQueryReader.java new file mode 100644 index 000000000..349e5c126 --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchQueryReader.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.storage; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.search.ClearScrollRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchScrollRequest; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.common.Strings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.flint.core.FlintOptions; +import org.opensearch.flint.core.IRestHighLevelClient; +import org.opensearch.search.builder.SearchSourceBuilder; + +import java.io.IOException; +import java.util.Optional; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * {@link OpenSearchReader} using search. https://opensearch.org/docs/latest/api-reference/search/ + */ +public class OpenSearchQueryReader extends OpenSearchReader { + + private static final Logger LOG = Logger.getLogger(OpenSearchQueryReader.class.getName()); + + public OpenSearchQueryReader(IRestHighLevelClient client, String indexName, SearchSourceBuilder searchSourceBuilder) { + super(client, new SearchRequest().indices(indexName).source(searchSourceBuilder)); + } + + /** + * search. + */ + Optional search(SearchRequest request) throws IOException { + return Optional.of(client.search(request, RequestOptions.DEFAULT)); + } + + /** + * nothing to clean + */ + void clean() throws IOException {} +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java index c70d327fe..e2e831bd0 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchReader.java @@ -5,6 +5,7 @@ package org.opensearch.flint.core.storage; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.flint.core.IRestHighLevelClient; @@ -48,6 +49,13 @@ public OpenSearchReader(IRestHighLevelClient client, SearchRequest searchRequest iterator = searchHits.iterator(); } return iterator.hasNext(); + } catch (OpenSearchStatusException e) { + // e.g., org.opensearch.OpenSearchStatusException: OpenSearch exception [type=index_not_found_exception, reason=no such index [query_results2]] + if (e.getMessage() != null && (e.getMessage().contains("index_not_found_exception"))) { + return false; + } else { + throw e; + } } catch (IOException e) { // todo. log error. throw new RuntimeException(e); diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index fd998d46d..359994c56 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -146,7 +146,30 @@ object FlintSparkConf { .datasourceOption() .doc("socket duration in milliseconds") .createWithDefault(String.valueOf(FlintOptions.DEFAULT_SOCKET_TIMEOUT_MILLIS)) - + val DATA_SOURCE_NAME = + FlintConfig(s"spark.flint.datasource.name") + .doc("data source name") + .createOptional() + val JOB_TYPE = + FlintConfig(s"spark.flint.job.type") + .doc("Flint job type. Including interactive and streaming") + .createWithDefault("interactive") + val SESSION_ID = + FlintConfig(s"spark.flint.job.sessionId") + .doc("Flint session id") + .createOptional() + val REQUEST_INDEX = + FlintConfig(s"spark.flint.job.requestIndex") + .doc("Request index") + .createOptional() + val EXCLUDE_JOB_IDS = + FlintConfig(s"spark.flint.deployment.excludeJobs") + .doc("Exclude job ids") + .createOptional() + val REPL_INACTIVITY_TIMEOUT_MILLIS = + FlintConfig(s"spark.flint.job.inactivityLimitMillis") + .doc("inactivity timeout") + .createWithDefault(String.valueOf(FlintOptions.DEFAULT_INACTIVITY_LIMIT_MILLIS)) } /** @@ -196,11 +219,18 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable CUSTOM_AWS_CREDENTIALS_PROVIDER, USERNAME, PASSWORD, - SOCKET_TIMEOUT_MILLIS) + SOCKET_TIMEOUT_MILLIS, + JOB_TYPE, + REPL_INACTIVITY_TIMEOUT_MILLIS) .map(conf => (conf.optionKey, conf.readFrom(reader))) .toMap - val optionsWithoutDefault = Seq(RETRYABLE_EXCEPTION_CLASS_NAMES) + val optionsWithoutDefault = Seq( + RETRYABLE_EXCEPTION_CLASS_NAMES, + DATA_SOURCE_NAME, + SESSION_ID, + REQUEST_INDEX, + EXCLUDE_JOB_IDS) .map(conf => (conf.optionKey, conf.readFrom(reader))) .flatMap { case (_, None) => None diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala index 5af70b793..9911a3b6c 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/app/FlintInstance.scala @@ -25,7 +25,14 @@ class FlintInstance( val lastUpdateTime: Long, val jobStartTime: Long = 0, val excludedJobIds: Seq[String] = Seq.empty[String], - val error: Option[String] = None) {} + val error: Option[String] = None) { + override def toString: String = { + val excludedJobIdsStr = excludedJobIds.mkString("[", ", ", "]") + val errorStr = error.getOrElse("None") + s"FlintInstance(applicationId=$applicationId, jobId=$jobId, sessionId=$sessionId, state=$state, " + + s"lastUpdateTime=$lastUpdateTime, jobStartTime=$jobStartTime, excludedJobIds=$excludedJobIdsStr, error=$errorStr)" + } +} object FlintInstance { diff --git a/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala b/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala new file mode 100644 index 000000000..046292316 --- /dev/null +++ b/integ-test/src/test/scala/org/apache/spark/sql/FlintJobITSuite.scala @@ -0,0 +1,259 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.{Duration, MINUTES} +import scala.util.{Failure, Success} +import org.opensearch.flint.core.FlintOptions +import org.opensearch.flint.spark.FlintSparkSuite +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName +import org.scalatest.matchers.must.Matchers.{defined, have} +import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the} +import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE +import org.apache.spark.sql.util.MockEnvironment +import org.apache.spark.util.ThreadUtils + +import scala.util.control.Breaks.{break, breakable} + +class FlintJobITSuite extends FlintSparkSuite with JobTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.skipping_sql_test" + private val testIndex = getSkippingIndexName(testTable) + val resultIndex = "query_results2" + val appId = "00feq82b752mbt0p" + val dataSourceName = "my_glue1" + var osClient: OSClient = _ + val threadLocalFuture = new ThreadLocal[Future[Unit]]() + + override def beforeAll(): Unit = { + super.beforeAll() + // initialized after the container is started + osClient = new OSClient(new FlintOptions(openSearchOptions.asJava)) + createPartitionedMultiRowTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + + deleteTestIndex(testIndex) + + waitJobStop(threadLocalFuture.get()) + + threadLocalFuture.remove() + } + + def waitJobStop(future: Future[Unit]): Unit = { + try { + val activeJob = spark.streams.active.find(_.name == testIndex) + if (activeJob.isDefined) { + activeJob.get.stop() + } + ThreadUtils.awaitResult(future, Duration(1, MINUTES)) + } catch { + case e: Exception => + e.printStackTrace() + assert(false, "failure waiting for job to finish") + } + } + + def startJob(query: String, jobRunId: String): Future[Unit] = { + val prefix = "flint-job-test" + val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor(prefix, 1) + implicit val executionContext = ExecutionContext.fromExecutor(threadPool) + + val futureResult = Future { + val job = + JobOperator(spark, query, dataSourceName, resultIndex, true) + job.envinromentProvider = new MockEnvironment( + Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId)) + + job.start() + } + futureResult.onComplete { + case Success(result) => logInfo(s"Success result: $result") + case Failure(ex) => + ex.printStackTrace() + assert(false, s"An error has occurred: ${ex.getMessage}") + } + futureResult + } + + test("create skipping index with auto refresh") { + val query = + s""" + | CREATE SKIPPING INDEX ON $testTable + | ( + | year PARTITION, + | name VALUE_SET, + | age MIN_MAX + | ) + | WITH (auto_refresh = true) + | """.stripMargin + val queryStartTime = System.currentTimeMillis() + val jobRunId = "00ff4o3b5091080q" + threadLocalFuture.set(startJob(query, jobRunId)) + + val validation: REPLResult => Boolean = result => { + assert( + result.results.size == 0, + s"expected result size is 0, but got ${result.results.size}") + assert( + result.schemas.size == 0, + s"expected schema size is 0, but got ${result.schemas.size}") + + assert(result.status == "SUCCESS", s"expected status is SUCCESS, but got ${result.status}") + assert(result.error.isEmpty, s"we don't expect error, but got ${result.error}") + assert(result.queryId.isEmpty, s"we don't expect query id, but got ${result.queryId}") + + commonAssert(result, jobRunId, query, queryStartTime) + true + } + pollForResultAndAssert(validation, jobRunId) + + val activeJob = spark.streams.active.find(_.name == testIndex) + activeJob shouldBe defined + failAfter(streamingTimeout) { + activeJob.get.processAllAvailable() + } + val indexData = spark.read.format(FLINT_DATASOURCE).load(testIndex) + flint.describeIndex(testIndex) shouldBe defined + indexData.count() shouldBe 2 + } + + test("create skipping index with non-existent table") { + val query = + s""" + | CREATE SKIPPING INDEX ON testTable + | ( + | year PARTITION, + | name VALUE_SET, + | age MIN_MAX + | ) + | WITH (auto_refresh = true) + | """.stripMargin + val queryStartTime = System.currentTimeMillis() + val jobRunId = "00ff4o3b5091080r" + threadLocalFuture.set(startJob(query, jobRunId)) + + val validation: REPLResult => Boolean = result => { + assert( + result.results.size == 0, + s"expected result size is 0, but got ${result.results.size}") + assert( + result.schemas.size == 0, + s"expected schema size is 0, but got ${result.schemas.size}") + + assert(result.status == "FAILED", s"expected status is FAILED, but got ${result.status}") + assert(!result.error.isEmpty, s"we expect error, but got ${result.error}") + commonAssert(result, jobRunId, query, queryStartTime) + true + } + pollForResultAndAssert(validation, jobRunId) + } + + test("describe skipping index") { + flint + .skippingIndex() + .onTable(testTable) + .addPartitions("year") + .addValueSet("name") + .addMinMax("age") + .create() + + val queryStartTime = System.currentTimeMillis() + val jobRunId = "00ff4o3b5091080s" + val query = s"DESC SKIPPING INDEX ON $testTable" + threadLocalFuture.set(startJob(query, jobRunId)) + + val validation: REPLResult => Boolean = result => { + assert( + result.results.size == 3, + s"expected result size is 3, but got ${result.results.size}") + val expectedResult0 = + "{'indexed_col_name':'year','data_type':'int','skip_type':'PARTITION'}" + assert( + result.results(0) == expectedResult0, + s"expected result size is $expectedResult0, but got ${result.results(0)}") + val expectedResult1 = + "{'indexed_col_name':'name','data_type':'string','skip_type':'VALUE_SET'}" + assert( + result.results(1) == expectedResult1, + s"expected result size is $expectedResult1, but got ${result.results(1)}") + val expectedResult2 = "{'indexed_col_name':'age','data_type':'int','skip_type':'MIN_MAX'}" + assert( + result.results(2) == expectedResult2, + s"expected result size is $expectedResult2, but got ${result.results(2)}") + assert( + result.schemas.size == 3, + s"expected schema size is 3, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'indexed_col_name','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected 0th field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'data_type','data_type':'string'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected 1st field is $expectedFirstSchema, but got ${result.schemas(1)}") + val expectedSecondSchema = "{'column_name':'skip_type','data_type':'string'}" + assert( + result.schemas(2).equals(expectedSecondSchema), + s"expected 2nd field is $expectedSecondSchema, but got ${result.schemas(2)}") + + assert(result.status == "SUCCESS", s"expected status is FAILED, but got ${result.status}") + assert(result.error.isEmpty, s"we expect error, but got ${result.error}") + + commonAssert(result, jobRunId, query, queryStartTime) + true + } + pollForResultAndAssert(validation, jobRunId) + } + + def commonAssert( + result: REPLResult, + jobRunId: String, + query: String, + queryStartTime: Long): Unit = { + assert( + result.jobRunId == jobRunId, + s"expected jobRunId is $jobRunId, but got ${result.jobRunId}") + assert( + result.applicationId == appId, + s"expected applicationId is $appId, but got ${result.applicationId}") + assert( + result.dataSourceName == dataSourceName, + s"expected data source is $dataSourceName, but got ${result.dataSourceName}") + val actualQueryText = normalizeString(result.queryText) + val expectedQueryText = normalizeString(query) + assert( + actualQueryText == expectedQueryText, + s"expected query is $expectedQueryText, but got $actualQueryText") + assert(result.sessionId.isEmpty, s"we don't expect session id, but got ${result.sessionId}") + assert( + result.updateTime > queryStartTime, + s"expect that update time is ${result.updateTime} later than query start time $queryStartTime, but it is not") + assert( + result.queryRunTime > 0, + s"expected query run time is positive, but got ${result.queryRunTime}") + assert( + result.queryRunTime < System.currentTimeMillis() - queryStartTime, + s"expected query run time ${result.queryRunTime} should be less than ${System + .currentTimeMillis() - queryStartTime}, but it is not") + assert(result.queryId.isEmpty, s"we don't expect query id, but got ${result.queryId}") + } + + def pollForResultAndAssert(expected: REPLResult => Boolean, jobId: String): Unit = { + pollForResultAndAssert( + osClient, + expected, + "jobRunId", + jobId, + streamingTimeout.toMillis, + resultIndex) + } +} diff --git a/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala b/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala new file mode 100644 index 000000000..9a2afc71e --- /dev/null +++ b/integ-test/src/test/scala/org/apache/spark/sql/FlintREPLITSuite.scala @@ -0,0 +1,573 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import java.util.UUID + +import scala.collection.JavaConverters._ +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.{Duration, MINUTES} +import scala.util.{Failure, Success, Try} +import scala.util.control.Breaks.{break, breakable} + +import org.opensearch.OpenSearchStatusException +import org.opensearch.flint.OpenSearchSuite +import org.opensearch.flint.app.{FlintCommand, FlintInstance} +import org.opensearch.flint.core.{FlintClient, FlintOptions} +import org.opensearch.flint.core.storage.{FlintOpenSearchClient, FlintReader, OpenSearchUpdater} +import org.opensearch.search.sort.SortOrder + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.flint.config.FlintSparkConf.{DATA_SOURCE_NAME, EXCLUDE_JOB_IDS, HOST_ENDPOINT, HOST_PORT, JOB_TYPE, REFRESH_POLICY, REPL_INACTIVITY_TIMEOUT_MILLIS, REQUEST_INDEX, SESSION_ID} +import org.apache.spark.sql.util.MockEnvironment +import org.apache.spark.util.ThreadUtils + +class FlintREPLITSuite extends SparkFunSuite with OpenSearchSuite with JobTest { + + var flintClient: FlintClient = _ + var osClient: OSClient = _ + var updater: OpenSearchUpdater = _ + val requestIndex = "flint_ql_sessions" + val resultIndex = "query_results2" + val jobRunId = "00ff4o3b5091080q" + val appId = "00feq82b752mbt0p" + val dataSourceName = "my_glue1" + val sessionId = "10" + val requestIndexMapping = + """ { + | "properties": { + | "applicationId": { + | "type": "keyword" + | }, + | "dataSourceName": { + | "type": "keyword" + | }, + | "error": { + | "type": "text" + | }, + | "excludeJobIds": { + | "type": "text", + | "fields": { + | "keyword": { + | "type": "keyword", + | "ignore_above": 256 + | } + | } + | }, + | "if_primary_term": { + | "type": "long" + | }, + | "if_seq_no": { + | "type": "long" + | }, + | "jobId": { + | "type": "keyword" + | }, + | "jobStartTime": { + | "type": "long" + | }, + | "lang": { + | "type": "keyword" + | }, + | "lastUpdateTime": { + | "type": "date", + | "format": "strict_date_time||epoch_millis" + | }, + | "query": { + | "type": "text" + | }, + | "queryId": { + | "type": "text", + | "fields": { + | "keyword": { + | "type": "keyword", + | "ignore_above": 256 + | } + | } + | }, + | "sessionId": { + | "type": "keyword" + | }, + | "state": { + | "type": "keyword" + | }, + | "statementId": { + | "type": "keyword" + | }, + | "submitTime": { + | "type": "date", + | "format": "strict_date_time||epoch_millis" + | }, + | "type": { + | "type": "keyword" + | } + | } + | } + |""".stripMargin + val testTable = dataSourceName + ".default.flint_sql_test" + + // use a thread-local variable to store and manage the future in beforeEach and afterEach + val threadLocalFuture = new ThreadLocal[Future[Unit]]() + + override def beforeAll(): Unit = { + super.beforeAll() + + flintClient = new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava)); + osClient = new OSClient(new FlintOptions(openSearchOptions.asJava)) + updater = new OpenSearchUpdater( + requestIndex, + new FlintOpenSearchClient(new FlintOptions(openSearchOptions.asJava))) + + } + + override def afterEach(): Unit = { + flintClient.deleteIndex(requestIndex) + super.afterEach() + } + + def createSession(jobId: String, excludeJobId: String): Unit = { + val docs = Seq(s"""{ + | "state": "running", + | "lastUpdateTime": 1698796582978, + | "applicationId": "00fd777k3k3ls20p", + | "error": "", + | "sessionId": ${sessionId}, + | "jobId": \"${jobId}\", + | "type": "session", + | "excludeJobIds": [\"${excludeJobId}\"] + |}""".stripMargin) + index(requestIndex, oneNodeSetting, requestIndexMapping, docs) + } + + def startREPL(): Future[Unit] = { + val prefix = "flint-repl-test" + val threadPool = ThreadUtils.newDaemonThreadPoolScheduledExecutor(prefix, 1) + implicit val executionContext = ExecutionContext.fromExecutor(threadPool) + + val futureResult = Future { + // SparkConf's constructor creates a SparkConf that loads defaults from system properties and the classpath. + // Read SparkConf.getSystemProperties + System.setProperty(DATA_SOURCE_NAME.key, "my_glue1") + System.setProperty(JOB_TYPE.key, "interactive") + System.setProperty(SESSION_ID.key, sessionId) + System.setProperty(REQUEST_INDEX.key, requestIndex) + System.setProperty(EXCLUDE_JOB_IDS.key, "00fer5qo32fa080q") + System.setProperty(REPL_INACTIVITY_TIMEOUT_MILLIS.key, "5000") + System.setProperty( + s"spark.sql.catalog.my_glue1", + "org.opensearch.sql.FlintDelegatingSessionCatalog") + System.setProperty("spark.master", "local") + System.setProperty(HOST_ENDPOINT.key, openSearchHost) + System.setProperty(HOST_PORT.key, String.valueOf(openSearchPort)) + System.setProperty(REFRESH_POLICY.key, "true") + + FlintREPL.envinromentProvider = new MockEnvironment( + Map("SERVERLESS_EMR_JOB_ID" -> jobRunId, "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID" -> appId)) + FlintREPL.enableHiveSupport = false + FlintREPL.terminateJVM = false + FlintREPL.main(Array("select 1", resultIndex)) + } + futureResult.onComplete { + case Success(result) => logInfo(s"Success result: $result") + case Failure(ex) => + ex.printStackTrace() + assert(false, s"An error has occurred: ${ex.getMessage}") + } + futureResult + } + + def waitREPLStop(future: Future[Unit]): Unit = { + try { + ThreadUtils.awaitResult(future, Duration(1, MINUTES)) + } catch { + case e: Exception => + e.printStackTrace() + assert(false, "failure waiting for REPL to finish") + } + } + + def submitQuery(query: String, queryId: String): String = { + submitQuery(query, queryId, System.currentTimeMillis()) + } + + def submitQuery(query: String, queryId: String, submitTime: Long): String = { + val statementId = UUID.randomUUID().toString + + updater.upsert( + statementId, + s"""{ + | "sessionId": "${sessionId}", + | "query": "${query}", + | "applicationId": "00fd775baqpu4g0p", + | "state": "waiting", + | "submitTime": $submitTime, + | "type": "statement", + | "statementId": "${statementId}", + | "queryId": "${queryId}", + | "dataSourceName": "${dataSourceName}" + |}""".stripMargin) + statementId + } + + test("sanity") { + try { + createSession(jobRunId, "") + threadLocalFuture.set(startREPL()) + + val createStatement = + s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\\t' + | ) + |""".stripMargin + submitQuery(s"${makeJsonCompliant(createStatement)}", "99") + + val insertStatement = + s""" + | INSERT INTO $testTable + | VALUES ('Hello', 30) + | """.stripMargin + submitQuery(s"${makeJsonCompliant(insertStatement)}", "100") + + val selectQueryId = "101" + val selectQueryStartTime = System.currentTimeMillis() + val selectQuery = s"SELECT name, age FROM $testTable".stripMargin + val selectStatementId = submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId) + + val describeStatement = s"DESC $testTable".stripMargin + val descQueryId = "102" + val descStartTime = System.currentTimeMillis() + val descStatementId = submitQuery(s"${makeJsonCompliant(describeStatement)}", descQueryId) + + val showTableStatement = + s"SHOW TABLES IN " + dataSourceName + ".default LIKE 'flint_sql_test'" + val showQueryId = "103" + val showStartTime = System.currentTimeMillis() + val showTableStatementId = + submitQuery(s"${makeJsonCompliant(showTableStatement)}", showQueryId) + + val wrongSelectQueryId = "104" + val wrongSelectQueryStartTime = System.currentTimeMillis() + val wrongSelectQuery = s"SELECT name, age FROM testTable".stripMargin + val wrongSelectStatementId = + submitQuery(s"${makeJsonCompliant(wrongSelectQuery)}", wrongSelectQueryId) + + val lateSelectQueryId = "105" + val lateSelectQuery = s"SELECT name, age FROM $testTable".stripMargin + // submitted from last year. We won't pick it up + val lateSelectStatementId = + submitQuery(s"${makeJsonCompliant(selectQuery)}", selectQueryId, 1672101970000L) + + // clean up + val dropStatement = + s"""DROP TABLE $testTable""".stripMargin + submitQuery(s"${makeJsonCompliant(dropStatement)}", "999") + + val selectQueryValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 1, + s"expected result size is 1, but got ${result.results.size}") + val expectedResult = "{'name':'Hello','age':30}" + assert( + result.results(0).equals(expectedResult), + s"expected result is $expectedResult, but got ${result.results(0)}") + assert( + result.schemas.size == 2, + s"expected schema size is 2, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'name','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'age','data_type':'integer'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + commonValidation(result, selectQueryId, selectQuery, selectQueryStartTime) + successValidation(result) + true + } + pollForResultAndAssert(selectQueryValidation, selectQueryId) + assert( + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + selectStatementId), + s"Fail to verify for $selectStatementId.") + + val descValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 2, + s"expected result size is 2, but got ${result.results.size}") + val expectedResult0 = "{'col_name':'name','data_type':'string'}" + assert( + result.results(0).equals(expectedResult0), + s"expected result is $expectedResult0, but got ${result.results(0)}") + val expectedResult1 = "{'col_name':'age','data_type':'int'}" + assert( + result.results(1).equals(expectedResult1), + s"expected result is $expectedResult1, but got ${result.results(1)}") + assert( + result.schemas.size == 3, + s"expected schema size is 3, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'col_name','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'data_type','data_type':'string'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + val expectedSecondSchema = "{'column_name':'comment','data_type':'string'}" + assert( + result.schemas(2).equals(expectedSecondSchema), + s"expected third field is $expectedSecondSchema, but got ${result.schemas(2)}") + commonValidation(result, descQueryId, describeStatement, descStartTime) + successValidation(result) + true + } + pollForResultAndAssert(descValidation, descQueryId) + assert( + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + descStatementId), + s"Fail to verify for $descStatementId.") + + val showValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 1, + s"expected result size is 1, but got ${result.results.size}") + val expectedResult = + "{'namespace':'default','tableName':'flint_sql_test','isTemporary':false}" + assert( + result.results(0).equals(expectedResult), + s"expected result is $expectedResult, but got ${result.results(0)}") + assert( + result.schemas.size == 3, + s"expected schema size is 3, but got ${result.schemas.size}") + val expectedZerothSchema = "{'column_name':'namespace','data_type':'string'}" + assert( + result.schemas(0).equals(expectedZerothSchema), + s"expected first field is $expectedZerothSchema, but got ${result.schemas(0)}") + val expectedFirstSchema = "{'column_name':'tableName','data_type':'string'}" + assert( + result.schemas(1).equals(expectedFirstSchema), + s"expected second field is $expectedFirstSchema, but got ${result.schemas(1)}") + val expectedSecondSchema = "{'column_name':'isTemporary','data_type':'boolean'}" + assert( + result.schemas(2).equals(expectedSecondSchema), + s"expected third field is $expectedSecondSchema, but got ${result.schemas(2)}") + commonValidation(result, showQueryId, showTableStatement, showStartTime) + successValidation(result) + true + } + pollForResultAndAssert(showValidation, showQueryId) + assert( + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "success" + }, + showTableStatementId), + s"Fail to verify for $showTableStatementId.") + + val wrongSelectQueryValidation: REPLResult => Boolean = result => { + assert( + result.results.size == 0, + s"expected result size is 0, but got ${result.results.size}") + assert( + result.schemas.size == 0, + s"expected schema size is 0, but got ${result.schemas.size}") + commonValidation(result, wrongSelectQueryId, wrongSelectQuery, wrongSelectQueryStartTime) + failureValidation(result) + true + } + pollForResultAndAssert(wrongSelectQueryValidation, wrongSelectQueryId) + assert( + !awaitConditionForStatementOrTimeout( + statement => { + statement.state == "failed" + }, + wrongSelectStatementId), + s"Fail to verify for $wrongSelectStatementId.") + + // expect time out as this statement should not be picked up + assert( + awaitConditionForStatementOrTimeout( + statement => { + statement.state != "waiting" + }, + lateSelectStatementId), + s"Fail to verify for $lateSelectStatementId.") + } catch { + case e: Exception => + logError("Unexpected exception", e) + assert(false, "Unexpected exception") + } finally { + waitREPLStop(threadLocalFuture.get()) + threadLocalFuture.remove() + + // shutdown hook is called after all tests have finished. We cannot verify if session has correctly been set in IT. + } + } + + /** + * JSON does not support raw newlines (\n) in string values. All newlines must be escaped or + * removed when inside a JSON string. The same goes for tab characters, which should be + * represented as \\t. + * + * Here, I replace the newlines with spaces and escape tab characters that is being included in + * the JSON. + * + * @param sqlQuery + * @return + */ + def makeJsonCompliant(sqlQuery: String): String = { + sqlQuery.replaceAll("\n", " ").replaceAll("\t", "\\\\t") + } + + def commonValidation( + result: REPLResult, + expectedQueryId: String, + expectedStatement: String, + queryStartTime: Long): Unit = { + assert( + result.jobRunId.equals(jobRunId), + s"expected job id is $jobRunId, but got ${result.jobRunId}") + assert( + result.applicationId.equals(appId), + s"expected app id is $appId, but got ${result.applicationId}") + assert( + result.dataSourceName.equals(dataSourceName), + s"expected data source is $dataSourceName, but got ${result.dataSourceName}") + assert( + result.queryId.equals(expectedQueryId), + s"expected query id is $expectedQueryId, but got ${result.queryId}") + assert( + result.queryText.equals(expectedStatement), + s"expected query is $expectedStatement, but got ${result.queryText}") + assert( + result.sessionId.equals(sessionId), + s"expected session id is $sessionId, but got ${result.sessionId}") + assert( + result.updateTime > queryStartTime, + s"expect that update time is ${result.updateTime} later than query start time $queryStartTime, but it is not") + assert( + result.queryRunTime > 0, + s"expected query run time is positive, but got ${result.queryRunTime}") + assert( + result.queryRunTime < System.currentTimeMillis() - queryStartTime, + s"expected query run time ${result.queryRunTime} should be less than ${System + .currentTimeMillis() - queryStartTime}, but it is not") + } + + def successValidation(result: REPLResult): Unit = { + assert( + result.status.equals("SUCCESS"), + s"expected status is SUCCESS, but got ${result.status}") + assert(result.error.isEmpty, s"we don't expect error, but got ${result.error}") + } + + def failureValidation(result: REPLResult): Unit = { + assert(result.status.equals("FAILED"), s"expected status is FAILED, but got ${result.status}") + assert(!result.error.isEmpty, s"we expect error, but got nothing") + } + + def pollForResultAndAssert(expected: REPLResult => Boolean, queryId: String): Unit = { + pollForResultAndAssert(osClient, expected, "queryId", queryId, 60000, resultIndex) + } + + /** + * Repeatedly polls a resource until a specified condition is met or a timeout occurs. + * + * This method continuously checks a resource for a specific condition. If the condition is met + * within the timeout period, the polling stops. If the timeout period is exceeded without the + * condition being met, an assertion error is thrown. + * + * @param osClient + * The OSClient used to poll the resource. + * @param condition + * A function that takes an instance of type T and returns a Boolean. This function defines + * the condition to be met. + * @param id + * The unique identifier of the resource to be polled. + * @param timeoutMillis + * The maximum amount of time (in milliseconds) to wait for the condition to be met. + * @param index + * The index in which the resource resides. + * @param deserialize + * A function that deserializes a String into an instance of type T. + * @param logType + * A descriptive string for logging purposes, indicating the type of resource being polled. + * @return + * whether timeout happened + * @throws OpenSearchStatusException + * if there's an issue fetching the resource. + */ + def awaitConditionOrTimeout[T]( + osClient: OSClient, + expected: T => Boolean, + id: String, + timeoutMillis: Long, + index: String, + deserialize: String => T, + logType: String): Boolean = { + val getResponse = osClient.getDoc(index, id) + val startTime = System.currentTimeMillis() + breakable { + while (System.currentTimeMillis() - startTime < timeoutMillis) { + logInfo(s"Check $logType for $id") + try { + if (getResponse.isExists()) { + val instance = deserialize(getResponse.getSourceAsString) + logInfo(s"$logType $id: $instance") + if (expected(instance)) { + break + } + } + } catch { + case e: OpenSearchStatusException => logError(s"Exception while fetching $logType", e) + } + Thread.sleep(2000) // 2 seconds + } + } + System.currentTimeMillis() - startTime >= timeoutMillis + } + + def awaitConditionForStatementOrTimeout( + expected: FlintCommand => Boolean, + statementId: String): Boolean = { + awaitConditionOrTimeout[FlintCommand]( + osClient, + expected, + statementId, + 10000, + requestIndex, + FlintCommand.deserialize, + "statement") + } + + def awaitConditionForSessionOrTimeout( + expected: FlintInstance => Boolean, + sessionId: String): Boolean = { + awaitConditionOrTimeout[FlintInstance]( + osClient, + expected, + sessionId, + 10000, + requestIndex, + FlintInstance.deserialize, + "session") + } +} diff --git a/integ-test/src/test/scala/org/apache/spark/sql/JobTest.scala b/integ-test/src/test/scala/org/apache/spark/sql/JobTest.scala new file mode 100644 index 000000000..563997b7f --- /dev/null +++ b/integ-test/src/test/scala/org/apache/spark/sql/JobTest.scala @@ -0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ +import scala.util.{Failure, Success} +import scala.util.control.Breaks._ + +import org.opensearch.OpenSearchStatusException +import org.opensearch.flint.OpenSearchSuite +import org.opensearch.flint.core.FlintOptions +import org.opensearch.search.sort.SortOrder + +import org.apache.spark.internal.Logging + +/** + * We use a self-type annotation (self: OpenSearchSuite =>) to specify that it must be mixed into + * a class that also mixes in OpenSearchSuite. This way, JobTest can still use the + * openSearchOptions field, + */ +trait JobTest extends Logging { self: OpenSearchSuite => + + def pollForResultAndAssert( + osClient: OSClient, + expected: REPLResult => Boolean, + idField: String, + idValue: String, + timeoutMillis: Long, + resultIndex: String): Unit = { + val query = + s"""{ + | "bool": { + | "must": [ + | { + | "term": { + | "$idField": "$idValue" + | } + | } + | ] + | } + |}""".stripMargin + val resultReader = osClient.createQueryReader(resultIndex, query, "updateTime", SortOrder.ASC) + + val startTime = System.currentTimeMillis() + breakable { + while (System.currentTimeMillis() - startTime < timeoutMillis) { + logInfo(s"Check result for $idValue") + try { + if (resultReader.hasNext()) { + REPLResult.deserialize(resultReader.next()) match { + case Success(replResult) => + logInfo(s"repl result: $replResult") + assert(expected(replResult), s"{$query} failed.") + case Failure(exception) => + assert(false, "Failed to deserialize: " + exception.getMessage) + } + break + } + } catch { + case e: OpenSearchStatusException => logError("Exception while querying for result", e) + } + + Thread.sleep(2000) // 2 seconds + } + if (System.currentTimeMillis() - startTime >= timeoutMillis) { + assert( + false, + s"Timeout occurred after $timeoutMillis milliseconds waiting for query result.") + } + } + } + + /** + * Used to preprocess multi-line queries before comparing them as serialized and deserialized + * queries might have different characters. + * @param s + * input + * @return + * normalized input by replacing all space, tab, ane newlines with single spaces. + */ + def normalizeString(s: String): String = { + // \\s+ is a regular expression that matches one or more whitespace characters, including spaces, tabs, and newlines. + s.replaceAll("\\s+", " ") + } // Replace all whitespace characters with empty string +} diff --git a/integ-test/src/test/scala/org/apache/spark/sql/REPLResult.scala b/integ-test/src/test/scala/org/apache/spark/sql/REPLResult.scala new file mode 100644 index 000000000..34dc2595c --- /dev/null +++ b/integ-test/src/test/scala/org/apache/spark/sql/REPLResult.scala @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql + +import scala.util.Try + +import org.json4s.{DefaultFormats, Formats} +import org.json4s.native.JsonMethods.parse + +class REPLResult( + val results: Seq[String], + val schemas: Seq[String], + val jobRunId: String, + val applicationId: String, + val dataSourceName: String, + val status: String, + val error: String, + val queryId: String, + val queryText: String, + val sessionId: String, + val updateTime: Long, + val queryRunTime: Long) { + override def toString: String = { + s"REPLResult(results=$results, schemas=$schemas, jobRunId=$jobRunId, applicationId=$applicationId, " + + s"dataSourceName=$dataSourceName, status=$status, error=$error, queryId=$queryId, queryText=$queryText, " + + s"sessionId=$sessionId, updateTime=$updateTime, queryRunTime=$queryRunTime)" + } +} + +object REPLResult { + implicit val formats: Formats = DefaultFormats + + def deserialize(jsonString: String): Try[REPLResult] = Try { + val json = parse(jsonString) + + new REPLResult( + results = (json \ "result").extract[Seq[String]], + schemas = (json \ "schema").extract[Seq[String]], + jobRunId = (json \ "jobRunId").extract[String], + applicationId = (json \ "applicationId").extract[String], + dataSourceName = (json \ "dataSourceName").extract[String], + status = (json \ "status").extract[String], + error = (json \ "error").extract[String], + queryId = (json \ "queryId").extract[String], + queryText = (json \ "queryText").extract[String], + sessionId = (json \ "sessionId").extract[String], + updateTime = (json \ "updateTime").extract[Long], + queryRunTime = (json \ "queryRunTime").extract[Long]) + } +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index 7af1c2639..4ab3a983b 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -51,6 +51,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit protected def deleteTestIndex(testIndexNames: String*): Unit = { testIndexNames.foreach(testIndex => { + /** * Todo, if state is not valid, will throw IllegalStateException. Should check flint * .isRefresh before cleanup resource. Current solution, (1) try to delete flint index, (2) diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala index 61564546e..575f09362 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala @@ -142,7 +142,8 @@ class FlintSparkPPLCorrelationITSuite assert( thrown.getMessage === "Correlation command was called with `fields` attribute having different elements from the 'mapping' attributes ") } - test("create failing ppl correlation query with no scope - due to mismatch fields to mappings test") { + test( + "create failing ppl correlation query with no scope - due to mismatch fields to mappings test") { val thrown = intercept[IllegalStateException] { val frame = sql(s""" | source = $testTable1, $testTable2| correlate exact fields(name, country) mapping($testTable1.name = $testTable2.name) diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala index 62ff50fb6..32c1baa0a 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala @@ -331,6 +331,7 @@ class FlintSparkPPLFiltersITSuite // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + /** * | age_span | country | average_age | * |:---------|:--------|:------------| diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index 750e228ef..df0bf5c4e 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -52,9 +52,13 @@ object FlintJob extends Logging with FlintJobExecutor { * Without this setup, Spark would not recognize names in the format `my_glue1.default`. */ conf.set("spark.sql.defaultCatalog", dataSource) - val jobOperator = - JobOperator(conf, query, dataSource, resultIndex, wait.equalsIgnoreCase("streaming")) + JobOperator( + createSparkSession(conf), + query, + dataSource, + resultIndex, + wait.equalsIgnoreCase("streaming")) jobOperator.start() } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index a44e70401..4a3c03d9b 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -18,11 +18,12 @@ import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.sql.FlintJob.{checkAndCreateIndex, createIndex, currentTimeProvider, executeQuery, getFailedData, getFormattedData, isSuperset, logError, logInfo, processQueryException, writeDataFrameToOpensearch} +import org.apache.spark.sql.FlintREPL.envinromentProvider import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructField, StructType} -import org.apache.spark.sql.util.{DefaultThreadPoolFactory, RealTimeProvider, ThreadPoolFactory, TimeProvider} +import org.apache.spark.sql.util.{DefaultThreadPoolFactory, EnvironmentProvider, RealEnvironment, RealTimeProvider, ThreadPoolFactory, TimeProvider} import org.apache.spark.util.ThreadUtils trait FlintJobExecutor { @@ -30,6 +31,8 @@ trait FlintJobExecutor { var currentTimeProvider: TimeProvider = new RealTimeProvider() var threadPoolFactory: ThreadPoolFactory = new DefaultThreadPoolFactory() + var envinromentProvider: EnvironmentProvider = new RealEnvironment() + var enableHiveSupport: Boolean = true // The enabled setting, which can be applied only to the top-level mapping definition and to object fields, val resultIndexMapping = @@ -87,7 +90,11 @@ trait FlintJobExecutor { } def createSparkSession(conf: SparkConf): SparkSession = { - SparkSession.builder().config(conf).enableHiveSupport().getOrCreate() + val builder = SparkSession.builder().config(conf) + if (enableHiveSupport) { + builder.enableHiveSupport() + } + builder.getOrCreate() } private def writeData(resultData: DataFrame, resultIndex: String): Unit = { @@ -177,8 +184,8 @@ trait FlintJobExecutor { ( resultToSave, resultSchemaToSave, - sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown"), - sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), + envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"), + envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), dataSource, "SUCCESS", "", @@ -226,8 +233,8 @@ trait FlintJobExecutor { ( null, null, - sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown"), - sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), + envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown"), + envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown"), dataSource, "FAILED", error, @@ -310,7 +317,8 @@ trait FlintJobExecutor { } } catch { case e: IllegalStateException - if e.getCause().getMessage().contains("index_not_found_exception") => + if e.getCause != null && + e.getCause.getMessage.contains("index_not_found_exception") => createIndex(osClient, resultIndex, resultIndexMapping) case e: InterruptedException => val error = s"Interrupted by the main thread: ${e.getMessage}" diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 6c3fd957d..2a63653e3 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -18,11 +18,15 @@ import org.opensearch.action.get.GetResponse import org.opensearch.common.Strings import org.opensearch.flint.app.{FlintCommand, FlintInstance} import org.opensearch.flint.app.FlintInstance.formats +import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater} +import org.opensearch.search.sort.SortOrder import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.sql.FlintJob.createSparkSession import org.apache.spark.sql.flint.config.FlintSparkConf +import org.apache.spark.sql.flint.config.FlintSparkConf.REPL_INACTIVITY_TIMEOUT_MILLIS import org.apache.spark.sql.util.{DefaultShutdownHookManager, ShutdownHookManagerTrait} import org.apache.spark.util.ThreadUtils @@ -42,13 +46,16 @@ import org.apache.spark.util.ThreadUtils object FlintREPL extends Logging with FlintJobExecutor { private val HEARTBEAT_INTERVAL_MILLIS = 60000L - private val DEFAULT_INACTIVITY_LIMIT_MILLIS = 10 * 60 * 1000 private val MAPPING_CHECK_TIMEOUT = Duration(1, MINUTES) private val DEFAULT_QUERY_EXECUTION_TIMEOUT = Duration(30, MINUTES) private val DEFAULT_QUERY_WAIT_TIMEOUT_MILLIS = 10 * 60 * 1000 val INITIAL_DELAY_MILLIS = 3000L val EARLY_TERMIANTION_CHECK_FREQUENCY = 60000L + @volatile var earlyExitFlag: Boolean = false + // termiante JVM in the presence non-deamon thread before exiting + var terminateJVM = true + def updateSessionIndex(flintCommand: FlintCommand, updater: OpenSearchUpdater): Unit = { updater.update(flintCommand.statementId, FlintCommand.serialize(flintCommand)) } @@ -61,7 +68,7 @@ object FlintREPL extends Logging with FlintJobExecutor { // init SparkContext val conf: SparkConf = createSparkConf() - val dataSource = conf.get("spark.flint.datasource.name", "unknown") + val dataSource = conf.get(FlintSparkConf.DATA_SOURCE_NAME.key, "unknown") // https://github.com/opensearch-project/opensearch-spark/issues/138 /* * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`, @@ -71,33 +78,36 @@ object FlintREPL extends Logging with FlintJobExecutor { * Without this setup, Spark would not recognize names in the format `my_glue1.default`. */ conf.set("spark.sql.defaultCatalog", dataSource) - val wait = conf.get("spark.flint.job.type", "continue") + val wait = conf.get(FlintSparkConf.JOB_TYPE.key, "continue") if (wait.equalsIgnoreCase("streaming")) { logInfo(s"""streaming query ${query}""") val jobOperator = - JobOperator(conf, query, dataSource, resultIndex, true) + JobOperator(createSparkSession(conf), query, dataSource, resultIndex, true) jobOperator.start() } else { // we don't allow default value for sessionIndex and sessionId. Throw exception if key not found. - val sessionIndex: Option[String] = Option(conf.get("spark.flint.job.requestIndex", null)) - val sessionId: Option[String] = Option(conf.get("spark.flint.job.sessionId", null)) + val sessionIndex: Option[String] = Option(conf.get(FlintSparkConf.REQUEST_INDEX.key, null)) + val sessionId: Option[String] = Option(conf.get(FlintSparkConf.SESSION_ID.key, null)) if (sessionIndex.isEmpty) { - throw new IllegalArgumentException("spark.flint.job.requestIndex is not set") + throw new IllegalArgumentException(FlintSparkConf.REQUEST_INDEX.key + " is not set") } if (sessionId.isEmpty) { - throw new IllegalArgumentException("spark.flint.job.sessionId is not set") + throw new IllegalArgumentException(FlintSparkConf.SESSION_ID.key + " is not set") } val spark = createSparkSession(conf) val osClient = new OSClient(FlintSparkConf().flintOptions()) - val jobId = sys.env.getOrElse("SERVERLESS_EMR_JOB_ID", "unknown") - val applicationId = sys.env.getOrElse("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") + val jobId = envinromentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") + val applicationId = + envinromentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") // Read the values from the Spark configuration or fall back to the default values val inactivityLimitMillis: Long = - conf.getLong("spark.flint.job.inactivityLimitMillis", DEFAULT_INACTIVITY_LIMIT_MILLIS) + conf.getLong( + FlintSparkConf.REPL_INACTIVITY_TIMEOUT_MILLIS.key, + FlintOptions.DEFAULT_INACTIVITY_LIMIT_MILLIS) val queryExecutionTimeoutSecs: Duration = Duration( conf.getLong( "spark.flint.job.queryExecutionTimeoutSec", @@ -136,6 +146,7 @@ object FlintREPL extends Logging with FlintJobExecutor { applicationId, flintSessionIndexUpdater, jobStartTime)) { + earlyExitFlag = true return } @@ -151,7 +162,6 @@ object FlintREPL extends Logging with FlintJobExecutor { queryExecutionTimeoutSecs, inactivityLimitMillis, queryWaitTimeoutMillis) - exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) { queryLoop(commandContext) } @@ -177,12 +187,12 @@ object FlintREPL extends Logging with FlintJobExecutor { // Check for non-daemon threads that may prevent the driver from shutting down. // Non-daemon threads other than the main thread indicate that the driver is still processing tasks, // which may be due to unresolved bugs in dependencies or threads not being properly shut down. - if (threadPoolFactory.hasNonDaemonThreadsOtherThanMain) { + if (terminateJVM && threadPoolFactory.hasNonDaemonThreadsOtherThanMain) { logInfo("A non-daemon thread in the driver is seen.") // Exit the JVM to prevent resource leaks and potential emr-s job hung. // A zero status code is used for a graceful shutdown without indicating an error. // If exiting with non-zero status, emr-s job will fail. - // This is a part of the fault tolerance mechanism to handle such scenarios gracefully. + // This is a part of the fault tolerance mechanism to handle such scenarios gracefully System.exit(0) } } @@ -232,7 +242,7 @@ object FlintREPL extends Logging with FlintJobExecutor { applicationId: String, flintSessionIndexUpdater: OpenSearchUpdater, jobStartTime: Long): Boolean = { - val confExcludeJobsOpt = conf.getOption("spark.flint.deployment.excludeJobs") + val confExcludeJobsOpt = conf.getOption(FlintSparkConf.EXCLUDE_JOB_IDS.key) confExcludeJobsOpt match { case None => @@ -505,6 +515,7 @@ object FlintREPL extends Logging with FlintJobExecutor { } if (!canPickNextStatementResult) { + earlyExitFlag = true canProceed = false } else if (!flintReader.hasNext) { canProceed = false @@ -559,9 +570,6 @@ object FlintREPL extends Logging with FlintJobExecutor { osClient: OSClient): Unit = { try { dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) - // todo. it is migration plan to handle https://github - // .com/opensearch-project/sql/issues/2436. Remove sleep after issue fixed in plugin. - Thread.sleep(2000) if (flintCommand.isRunning() || flintCommand.isWaiting()) { // we have set failed state in exception handling flintCommand.complete() @@ -814,7 +822,7 @@ object FlintREPL extends Logging with FlintJobExecutor { | } |}""".stripMargin - val flintReader = osClient.createReader(sessionIndex, dsl, "submitTime") + val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC) flintReader } @@ -838,7 +846,15 @@ object FlintREPL extends Logging with FlintJobExecutor { } val state = Option(source.get("state")).map(_.asInstanceOf[String]) - if (state.isDefined && state.get != "dead" && state.get != "fail") { + // It's essential to check the earlyExitFlag before marking the session state as 'dead'. When this flag is true, + // it indicates that the control plane has already initiated a new session to handle remaining requests for the + // current session. In our SQL setup, marking a session as 'dead' automatically triggers the creation of a new + // session. However, the newly created session (initiated by the control plane) will enter a spin-wait state, + // where it inefficiently waits for certain conditions to be met, leading to unproductive resource consumption + // and eventual timeout. To avoid this issue and prevent the creation of redundant sessions by SQL, we ensure + // the session state is not set to 'dead' when earlyExitFlag is true, thereby preventing unnecessary duplicate + // processing. + if (!earlyExitFlag && state.isDefined && state.get != "dead" && state.get != "fail") { updateFlintInstanceBeforeShutdown( source, getResponse, diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index c60d250ea..3b5aa474a 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -21,14 +21,13 @@ import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.util.ThreadUtils case class JobOperator( - sparkConf: SparkConf, + spark: SparkSession, query: String, dataSource: String, resultIndex: String, streaming: Boolean) extends Logging with FlintJobExecutor { - private val spark = createSparkSession(sparkConf) // jvm shutdown hook sys.addShutdownHook(stop()) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala index e2e44bddd..cd784e704 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/OSClient.scala @@ -10,6 +10,7 @@ import java.util.ArrayList import java.util.Locale import org.opensearch.action.get.{GetRequest, GetResponse} +import org.opensearch.action.search.{SearchRequest, SearchResponse} import org.opensearch.client.{RequestOptions, RestHighLevelClient} import org.opensearch.client.indices.{CreateIndexRequest, GetIndexRequest, GetIndexResponse} import org.opensearch.client.indices.CreateIndexRequest @@ -18,7 +19,7 @@ import org.opensearch.common.settings.Settings import org.opensearch.common.xcontent.{NamedXContentRegistry, XContentParser, XContentType} import org.opensearch.common.xcontent.DeprecationHandler.IGNORE_DEPRECATIONS import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions} -import org.opensearch.flint.core.storage.{FlintReader, OpenSearchScrollReader, OpenSearchUpdater} +import org.opensearch.flint.core.storage.{FlintReader, OpenSearchQueryReader, OpenSearchScrollReader, OpenSearchUpdater} import org.opensearch.index.query.{AbstractQueryBuilder, MatchAllQueryBuilder, QueryBuilder} import org.opensearch.plugins.SearchPlugin import org.opensearch.search.SearchModule @@ -117,14 +118,14 @@ class OSClient(val flintOptions: FlintOptions) extends Logging { String.format( Locale.ROOT, "Failed to retrieve doc %s from index %s", - osIndexName, - id), + id, + osIndexName), e) } } } - def createReader(indexName: String, query: String, sort: String): FlintReader = try { + def createScrollReader(indexName: String, query: String, sort: String): FlintReader = try { var queryBuilder: QueryBuilder = new MatchAllQueryBuilder if (!Strings.isNullOrEmpty(query)) { val parser = @@ -152,4 +153,24 @@ class OSClient(val flintOptions: FlintOptions) extends Logging { } } } + + def createQueryReader( + indexName: String, + query: String, + sort: String, + sortOrder: SortOrder): FlintReader = try { + var queryBuilder: QueryBuilder = new MatchAllQueryBuilder + if (!Strings.isNullOrEmpty(query)) { + val parser = + XContentType.JSON.xContent.createParser(xContentRegistry, IGNORE_DEPRECATIONS, query) + queryBuilder = AbstractQueryBuilder.parseInnerQueryBuilder(parser) + } + new OpenSearchQueryReader( + flintClient.createClient(), + indexName, + new SearchSourceBuilder().query(queryBuilder).sort(sort, sortOrder)) + } catch { + case e: IOException => + throw new RuntimeException(e) + } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/util/EnvironmentProvider.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/EnvironmentProvider.scala new file mode 100644 index 000000000..ff167444e --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/EnvironmentProvider.scala @@ -0,0 +1,10 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.util + +trait EnvironmentProvider { + def getEnvVar(name: String, default: String): String +} diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/util/RealEnvironment.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/RealEnvironment.scala new file mode 100644 index 000000000..b01a2fdec --- /dev/null +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/RealEnvironment.scala @@ -0,0 +1,10 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.util + +class RealEnvironment extends EnvironmentProvider { + def getEnvVar(name: String, default: String): String = sys.env.getOrElse(name, default) +} diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index c3d027102..3e9d408e6 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -15,6 +15,7 @@ import scala.concurrent.duration._ import scala.concurrent.duration.{Duration, MINUTES} import scala.reflect.runtime.universe.TypeTag +import org.mockito.ArgumentMatchers.{eq => eqTo, _} import org.mockito.ArgumentMatchersSugar import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -22,12 +23,13 @@ import org.mockito.stubbing.Answer import org.opensearch.action.get.GetResponse import org.opensearch.flint.app.FlintCommand import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater} +import org.opensearch.search.sort.SortOrder import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.trees.Origin -import org.apache.spark.sql.types.{ArrayType, LongType, NullType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{LongType, NullType, StringType, StructField, StructType} import org.apache.spark.sql.util.{DefaultThreadPoolFactory, MockThreadPoolFactory, MockTimeProvider, RealTimeProvider, ShutdownHookManagerTrait} import org.apache.spark.util.ThreadUtils @@ -411,7 +413,8 @@ class FlintREPLTest new ConnectException( "Timeout connecting to [search-foo-1-bar.eu-west-1.es.amazonaws.com:443]")) val osClient = mock[OSClient] - when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader) + when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) when(mockReader.hasNext).thenThrow(exception) val maxRetries = 1 @@ -686,7 +689,8 @@ class FlintREPLTest test("queryLoop continue until inactivity limit is reached") { val mockReader = mock[FlintReader] val osClient = mock[OSClient] - when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader) + when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) when(mockReader.hasNext).thenReturn(false) val resultIndex = "testResultIndex" @@ -736,7 +740,8 @@ class FlintREPLTest test("queryLoop should stop when canPickUpNextStatement is false") { val mockReader = mock[FlintReader] val osClient = mock[OSClient] - when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader) + when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) when(mockReader.hasNext).thenReturn(true) val resultIndex = "testResultIndex" @@ -790,7 +795,8 @@ class FlintREPLTest test("queryLoop should properly shut down the thread pool after execution") { val mockReader = mock[FlintReader] val osClient = mock[OSClient] - when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader) + when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) when(mockReader.hasNext).thenReturn(false) val resultIndex = "testResultIndex" @@ -838,7 +844,8 @@ class FlintREPLTest test("queryLoop handle exceptions within the loop gracefully") { val mockReader = mock[FlintReader] val osClient = mock[OSClient] - when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader) + when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) // Simulate an exception thrown when hasNext is called when(mockReader.hasNext).thenThrow(new RuntimeException("Test exception")) @@ -889,7 +896,8 @@ class FlintREPLTest test("queryLoop should correctly update loop control variables") { val mockReader = mock[FlintReader] val osClient = mock[OSClient] - when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader) + when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) val getResponse = mock[GetResponse] when(osClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(false) @@ -958,7 +966,8 @@ class FlintREPLTest test("queryLoop should execute loop without processing any commands") { val mockReader = mock[FlintReader] val osClient = mock[OSClient] - when(osClient.createReader(any[String], any[String], any[String])).thenReturn(mockReader) + when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) + .thenReturn(mockReader) val getResponse = mock[GetResponse] when(osClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(false) diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/util/MockEnvironment.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/util/MockEnvironment.scala new file mode 100644 index 000000000..637c7d91c --- /dev/null +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/util/MockEnvironment.scala @@ -0,0 +1,10 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.util + +class MockEnvironment(inputMap: Map[String, String]) extends EnvironmentProvider { + def getEnvVar(name: String, default: String): String = inputMap.getOrElse(name, default) +}