From e7d087421430ade0b0ad519cafdfce23023d7c66 Mon Sep 17 00:00:00 2001
From: Peng Huo
Date: Thu, 18 Apr 2024 10:51:56 -0700
Subject: [PATCH] Clean up Spark shuffle data after data is consumed

Signed-off-by: Peng Huo
---
 .../scala/org/apache/spark/sql/FlintJob.scala | 18 ++----
 .../apache/spark/sql/FlintJobExecutor.scala   | 30 +++++-----
 .../org/apache/spark/sql/FlintREPL.scala      | 12 ++--
 .../org/apache/spark/sql/JobOperator.scala    | 10 ++--
 .../spark/sql/util/ShuffleCleaner.scala       | 57 +++++++++++++++++++
 .../org/apache/spark/sql/FlintJobTest.scala   |  5 +-
 .../spark/sql/util/CleanerFactoryTest.scala   | 23 ++++++++
 7 files changed, 115 insertions(+), 40 deletions(-)
 create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/util/ShuffleCleaner.scala
 create mode 100644 spark-sql-application/src/test/scala/org/apache/spark/sql/util/CleanerFactoryTest.scala

diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
index 8b4bdeeaf..66859c9f4 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala
@@ -6,24 +6,14 @@
 // defined in spark package so that I can use ThreadUtils
 package org.apache.spark.sql
 
-import java.util.Locale
-import java.util.concurrent.atomic.AtomicInteger
-
-import org.opensearch.client.{RequestOptions, RestHighLevelClient}
-import org.opensearch.cluster.metadata.MappingMetadata
-import org.opensearch.common.settings.Settings
-import org.opensearch.common.xcontent.XContentType
-import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions}
-import org.opensearch.flint.core.metadata.FlintMetadata
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.flint.config.FlintSparkConf
+import org.apache.spark.sql.types._
 import org.opensearch.flint.core.metrics.MetricConstants
 import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge
 import play.api.libs.json._
 
-import org.apache.spark.SparkConf
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.flint.config.FlintSparkConf
-import org.apache.spark.sql.types.{StructField, _}
+import java.util.concurrent.atomic.AtomicInteger
 
 /**
  * Spark SQL Application entrypoint
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
index ccd5c8f3f..915de4089 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala
@@ -7,25 +7,17 @@ package org.apache.spark.sql
 
 import java.util.Locale
 
-import scala.concurrent.{ExecutionContext, Future, TimeoutException}
-import scala.concurrent.duration.{Duration, MINUTES}
-
 import com.amazonaws.services.s3.model.AmazonS3Exception
-import org.opensearch.flint.core.{FlintClient, IRestHighLevelClient}
-import org.opensearch.flint.core.metadata.FlintMetadata
+import org.opensearch.flint.core.IRestHighLevelClient
 import org.opensearch.flint.core.metrics.MetricConstants
 import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
-import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue}
+import play.api.libs.json._
 
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.FlintREPL.envinromentProvider
 import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.flint.config.FlintSparkConf
-import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructField, StructType}
-import org.apache.spark.sql.util.{DefaultThreadPoolFactory, EnvironmentProvider, RealEnvironment, RealTimeProvider, ThreadPoolFactory, TimeProvider}
-import org.apache.spark.util.ThreadUtils
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util._
 
 trait FlintJobExecutor {
   this: Logging =>
@@ -156,7 +148,8 @@
       query: String,
       sessionId: String,
       startTime: Long,
-      timeProvider: TimeProvider): DataFrame = {
+      timeProvider: TimeProvider,
+      cleaner: Cleaner): DataFrame = {
     // Create the schema dataframe
     val schemaRows = result.schema.fields.map { field =>
       Row(field.name, field.dataType.typeName)
@@ -191,6 +184,11 @@
     val resultSchemaToSave = resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'"))
     val endTime = timeProvider.currentEpochMillis()
 
+    // https://github.com/opensearch-project/opensearch-spark/issues/302. Clean up shuffle data
+    // after the query result has been consumed. Streaming query shuffle data is cleaned up
+    // after each microbatch execution.
+    cleaner.cleanUp(spark)
+
     // Create the data rows
     val rows = Seq(
       (
@@ -366,7 +364,8 @@
       query: String,
       dataSource: String,
       queryId: String,
-      sessionId: String): DataFrame = {
+      sessionId: String,
+      streaming: Boolean): DataFrame = {
     // Execute SQL query
     val startTime = System.currentTimeMillis()
     // we have to set job group in the same thread that started the query according to spark doc
@@ -381,7 +380,8 @@
       query,
       sessionId,
       startTime,
-      currentTimeProvider)
+      currentTimeProvider,
+      CleanerFactory.cleaner(streaming))
   }
 
   private def handleQueryException(
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
index 76e5f692c..b4c27f3c8 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala
@@ -22,15 +22,13 @@
 import org.opensearch.flint.app.{FlintCommand, FlintInstance}
 import org.opensearch.flint.app.FlintInstance.formats
 import org.opensearch.flint.core.FlintOptions
 import org.opensearch.flint.core.metrics.MetricConstants
-import org.opensearch.flint.core.metrics.MetricsUtil.{decrementCounter, getTimerContext, incrementCounter, registerGauge, stopTimer}
+import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer}
 import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}
 import org.opensearch.search.sort.SortOrder
 
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.FlintJob.createSparkSession
 import org.apache.spark.sql.flint.config.FlintSparkConf
-import org.apache.spark.sql.flint.config.FlintSparkConf.REPL_INACTIVITY_TIMEOUT_MILLIS
 import org.apache.spark.sql.util.{DefaultShutdownHookManager, ShutdownHookManagerTrait}
 import org.apache.spark.util.ThreadUtils
@@ -829,7 +827,13 @@
             startTime)
         } else {
           val futureQueryExecution = Future {
-            executeQuery(spark, flintCommand.query, dataSource, flintCommand.queryId, sessionId)
+            executeQuery(
+              spark,
+              flintCommand.query,
+              dataSource,
+              flintCommand.queryId,
+              sessionId,
+              streaming = false)
           }(executionContext)
 
           // time out after 10 minutes
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
index 4fb272938..5e68e30e8 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
@@ -14,13 +14,10 @@ import scala.util.{Failure, Success, Try}
 
 import org.opensearch.flint.core.metrics.MetricConstants
 import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
-import org.opensearch.flint.core.storage.OpenSearchUpdater
 
-import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.FlintJob.createSparkSession
-import org.apache.spark.sql.FlintREPL.{executeQuery, logInfo, threadPoolFactory, updateFlintInstanceBeforeShutdown}
 import org.apache.spark.sql.flint.config.FlintSparkConf
+import org.apache.spark.sql.util.ShuffleCleaner
 import org.apache.spark.util.ThreadUtils
 
 case class JobOperator(
@@ -53,7 +50,7 @@
     val futureMappingCheck = Future {
       checkAndCreateIndex(osClient, resultIndex)
     }
-    val data = executeQuery(spark, query, dataSource, "", "")
+    val data = executeQuery(spark, query, dataSource, "", "", streaming)
 
     val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES))
     dataToWrite = Some(mappingCheckResult match {
@@ -92,6 +89,8 @@
     try {
      // Wait for streaming job complete if no error and there is streaming job running
      if (!exceptionThrown && streaming && spark.streams.active.nonEmpty) {
+        // Register a listener that cleans up shuffle data after each microbatch
+        spark.streams.addListener(new ShuffleCleaner(spark))
        // wait if any child thread to finish before the main thread terminates
        spark.streams.awaitAnyTermination()
      }
@@ -149,4 +148,5 @@
       case false => incrementCounter(MetricConstants.STREAMING_SUCCESS_METRIC)
     }
   }
+
 }
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ShuffleCleaner.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ShuffleCleaner.scala
new file mode 100644
index 000000000..a9edf9268
--- /dev/null
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/util/ShuffleCleaner.scala
@@ -0,0 +1,57 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql.util
+
+import org.apache.spark.{MapOutputTrackerMaster, SparkEnv}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.streaming.StreamingQueryListener
+
+/**
+ * Cleans up Spark shuffle data after each microbatch of a streaming query.
+ * https://github.com/opensearch-project/opensearch-spark/issues/302
+ */
+class ShuffleCleaner(spark: SparkSession) extends StreamingQueryListener with Logging {
+
+  override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {}
+
+  override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
+    ShuffleCleaner.cleanUp(spark)
+  }
+
+  override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {}
+}
+
+trait Cleaner {
+  def cleanUp(spark: SparkSession): Unit
+}
+
+object CleanerFactory {
+  def cleaner(streaming: Boolean): Cleaner = {
+    if (streaming) NoOpCleaner else ShuffleCleaner
+  }
+}
+
+/**
+ * No-op cleaner, used for streaming queries whose shuffle data is cleaned up per microbatch.
+ */
+object NoOpCleaner extends Cleaner {
+  override def cleanUp(spark: SparkSession): Unit = {}
+}
+
+/**
+ * Spark shuffle data cleaner.
+ */
+object ShuffleCleaner extends Cleaner with Logging {
+  def cleanUp(spark: SparkSession): Unit = {
+    logInfo("Cleaning up Spark shuffle data")
+    val cleaner = spark.sparkContext.cleaner
+    val masterTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+    val shuffleIds = masterTracker.shuffleStatuses.keys.toSet
+    shuffleIds.foreach(shuffleId => cleaner.foreach(c => c.doCleanupShuffle(shuffleId, true)))
+    logInfo("Finished cleaning up Spark shuffle data")
+  }
+}
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala
index 065c0bb67..352d140ce 100644
--- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala
@@ -7,7 +7,7 @@ package org.apache.spark.sql
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.util.MockTimeProvider
+import org.apache.spark.sql.util.{CleanerFactory, MockTimeProvider}
 
 class FlintJobTest extends SparkFunSuite with JobMatchers {
 
@@ -76,7 +76,8 @@
       "select 1",
       "20",
       currentTime - queryRunTime,
-      new MockTimeProvider(currentTime))
+      new MockTimeProvider(currentTime),
+      CleanerFactory.cleaner(streaming = false))
     assertEqualDataframe(expected, result)
   }
 
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/util/CleanerFactoryTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/util/CleanerFactoryTest.scala
new file mode 100644
index 000000000..d061d2354
--- /dev/null
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/util/CleanerFactoryTest.scala
@@ -0,0 +1,23 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql.util
+
+import org.scalatest.matchers.should.Matchers
+
+import org.apache.spark.SparkFunSuite
+
+class CleanerFactoryTest extends SparkFunSuite with Matchers {
+
+  test("CleanerFactory should return NoOpCleaner when streaming is true") {
+    val cleaner = CleanerFactory.cleaner(streaming = true)
+    cleaner shouldBe NoOpCleaner
+  }
+
+  test("CleanerFactory should return ShuffleCleaner when streaming is false") {
+    val cleaner = CleanerFactory.cleaner(streaming = false)
+    cleaner shouldBe ShuffleCleaner
+  }
+}
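---

Reviewer note, not part of the patch: below is a minimal, hypothetical sketch of the batch-path cleanup that `executeQuery` now triggers via `CleanerFactory.cleaner(streaming = false)`. The object name `ShuffleCleanupDemo` and the local-mode session are assumptions for illustration only. Like the patch's own classes, it must live inside the `org.apache.spark` namespace, because `SparkContext.cleaner`, `MapOutputTrackerMaster.shuffleStatuses`, and `ContextCleaner.doCleanupShuffle` are all `private[spark]`.

```scala
// Hypothetical demo, not part of this patch: reproduces the same cleanup
// sequence as ShuffleCleaner.cleanUp on a toy batch query. Placed in the
// org.apache.spark namespace because the accessors used below are
// private[spark].
package org.apache.spark.sql.util

import org.apache.spark.{MapOutputTrackerMaster, SparkEnv}
import org.apache.spark.sql.SparkSession

object ShuffleCleanupDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[2]")
      .appName("shuffle-cleanup-demo")
      .getOrCreate()
    import spark.implicits._

    // groupBy forces a shuffle, so the driver's MapOutputTrackerMaster
    // registers a shuffle status once the collect() action runs. Consuming
    // the result first mirrors what executeQuery does before cleanUp.
    Seq("a", "b", "a").toDF("k").groupBy("k").count().collect()

    val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
    println(s"shuffles before cleanup: ${tracker.shuffleStatuses.keys.toSet}")

    // Same sequence as ShuffleCleaner.cleanUp: for every registered shuffle,
    // ask the ContextCleaner to drop its map output state and shuffle files,
    // blocking until each removal finishes.
    val shuffleIds = tracker.shuffleStatuses.keys.toSet
    spark.sparkContext.cleaner.foreach { c =>
      shuffleIds.foreach(id => c.doCleanupShuffle(id, blocking = true))
    }

    println(s"shuffles after cleanup: ${tracker.shuffleStatuses.keys.toSet}")
    spark.stop()
  }
}
```

The streaming path differs only in its trigger: `JobOperator` registers `new ShuffleCleaner(spark)` as a `StreamingQueryListener`, so the same `cleanUp` runs on every `onQueryProgress` event instead of once after the result is collected. That is why `CleanerFactory` hands batch queries the `ShuffleCleaner` and streaming queries the `NoOpCleaner`.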