Skip to content

Commit

Permalink
Cleanup Spark shuffle data after data is consumed
Browse files Browse the repository at this point in the history
Signed-off-by: Peng Huo <[email protected]>
  • Loading branch information
penghuo committed Apr 18, 2024
1 parent a38747f commit e7d0874
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,14 @@
// defined in spark package so that I can use ThreadUtils
package org.apache.spark.sql

import java.util.Locale
import java.util.concurrent.atomic.AtomicInteger

import org.opensearch.client.{RequestOptions, RestHighLevelClient}
import org.opensearch.cluster.metadata.MappingMetadata
import org.opensearch.common.settings.Settings
import org.opensearch.common.xcontent.XContentType
import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions}
import org.opensearch.flint.core.metadata.FlintMetadata
import org.apache.spark.internal.Logging
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.types._
import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge
import play.api.libs.json._

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.types.{StructField, _}
import java.util.concurrent.atomic.AtomicInteger

/**
* Spark SQL Application entrypoint
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,17 @@ package org.apache.spark.sql

import java.util.Locale

import scala.concurrent.{ExecutionContext, Future, TimeoutException}
import scala.concurrent.duration.{Duration, MINUTES}

import com.amazonaws.services.s3.model.AmazonS3Exception
import org.opensearch.flint.core.{FlintClient, IRestHighLevelClient}
import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.core.IRestHighLevelClient
import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue}
import play.api.libs.json._

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintREPL.envinromentProvider
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructField, StructType}
import org.apache.spark.sql.util.{DefaultThreadPoolFactory, EnvironmentProvider, RealEnvironment, RealTimeProvider, ThreadPoolFactory, TimeProvider}
import org.apache.spark.util.ThreadUtils
import org.apache.spark.sql.types._
import org.apache.spark.sql.util._

trait FlintJobExecutor {
this: Logging =>
Expand Down Expand Up @@ -156,7 +148,8 @@ trait FlintJobExecutor {
query: String,
sessionId: String,
startTime: Long,
timeProvider: TimeProvider): DataFrame = {
timeProvider: TimeProvider,
cleaner: Cleaner): DataFrame = {
// Create the schema dataframe
val schemaRows = result.schema.fields.map { field =>
Row(field.name, field.dataType.typeName)
Expand Down Expand Up @@ -191,6 +184,11 @@ trait FlintJobExecutor {
val resultSchemaToSave = resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'"))
val endTime = timeProvider.currentEpochMillis()

// https://github.com/opensearch-project/opensearch-spark/issues/302. Clean shuffle data
// after consumed the query result. Streaming query shuffle data is cleaned after each
// microBatch execution.
cleaner.cleanUp(spark)

// Create the data rows
val rows = Seq(
(
Expand Down Expand Up @@ -366,7 +364,8 @@ trait FlintJobExecutor {
query: String,
dataSource: String,
queryId: String,
sessionId: String): DataFrame = {
sessionId: String,
streaming: Boolean): DataFrame = {
// Execute SQL query
val startTime = System.currentTimeMillis()
// we have to set job group in the same thread that started the query according to spark doc
Expand All @@ -381,7 +380,8 @@ trait FlintJobExecutor {
query,
sessionId,
startTime,
currentTimeProvider)
currentTimeProvider,
CleanerFactory.cleaner(streaming))
}

private def handleQueryException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,13 @@ import org.opensearch.flint.app.{FlintCommand, FlintInstance}
import org.opensearch.flint.app.FlintInstance.formats
import org.opensearch.flint.core.FlintOptions
import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.{decrementCounter, getTimerContext, incrementCounter, registerGauge, stopTimer}
import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer}
import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}
import org.opensearch.search.sort.SortOrder

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintJob.createSparkSession
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.flint.config.FlintSparkConf.REPL_INACTIVITY_TIMEOUT_MILLIS
import org.apache.spark.sql.util.{DefaultShutdownHookManager, ShutdownHookManagerTrait}
import org.apache.spark.util.ThreadUtils

Expand Down Expand Up @@ -829,7 +827,13 @@ object FlintREPL extends Logging with FlintJobExecutor {
startTime)
} else {
val futureQueryExecution = Future {
executeQuery(spark, flintCommand.query, dataSource, flintCommand.queryId, sessionId)
executeQuery(
spark,
flintCommand.query,
dataSource,
flintCommand.queryId,
sessionId,
false)
}(executionContext)

// time out after 10 minutes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,10 @@ import scala.util.{Failure, Success, Try}

import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
import org.opensearch.flint.core.storage.OpenSearchUpdater

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintJob.createSparkSession
import org.apache.spark.sql.FlintREPL.{executeQuery, logInfo, threadPoolFactory, updateFlintInstanceBeforeShutdown}
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.util.ShuffleCleaner
import org.apache.spark.util.ThreadUtils

case class JobOperator(
Expand Down Expand Up @@ -53,7 +50,7 @@ case class JobOperator(
val futureMappingCheck = Future {
checkAndCreateIndex(osClient, resultIndex)
}
val data = executeQuery(spark, query, dataSource, "", "")
val data = executeQuery(spark, query, dataSource, "", "", streaming)

val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES))
dataToWrite = Some(mappingCheckResult match {
Expand Down Expand Up @@ -92,6 +89,8 @@ case class JobOperator(
try {
// Wait for streaming job complete if no error and there is streaming job running
if (!exceptionThrown && streaming && spark.streams.active.nonEmpty) {
//
spark.streams.addListener(new ShuffleCleaner(spark))
// wait if any child thread to finish before the main thread terminates
spark.streams.awaitAnyTermination()
}
Expand Down Expand Up @@ -149,4 +148,5 @@ case class JobOperator(
case false => incrementCounter(MetricConstants.STREAMING_SUCCESS_METRIC)
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql.util

import org.apache.spark.{MapOutputTrackerMaster, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryListener

/**
* Clean Spark shuffle data after each microBatch.
* https://github.com/opensearch-project/opensearch-spark/issues/302
*/
class ShuffleCleaner(spark: SparkSession) extends StreamingQueryListener with Logging {

override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {}

override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
ShuffleCleaner.cleanUp(spark)
}

override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {}
}

trait Cleaner {
def cleanUp(spark: SparkSession)
}

object CleanerFactory {
def cleaner(streaming: Boolean): Cleaner = {
if (streaming) NoOpCleaner else ShuffleCleaner
}
}

/**
* No operation cleaner.
*/
object NoOpCleaner extends Cleaner {
override def cleanUp(spark: SparkSession): Unit = {}
}

/**
* Spark shuffle data cleaner.
*/
object ShuffleCleaner extends Cleaner with Logging {
def cleanUp(spark: SparkSession): Unit = {
logInfo("Before cleanUp Shuffle")
val cleaner = spark.sparkContext.cleaner
val masterTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
val shuffleIds = masterTracker.shuffleStatuses.keys.toSet
shuffleIds.foreach(shuffleId => cleaner.foreach(c => c.doCleanupShuffle(shuffleId, true)))
logInfo("After cleanUp Shuffle")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.MockTimeProvider
import org.apache.spark.sql.util.{CleanerFactory, MockTimeProvider}

class FlintJobTest extends SparkFunSuite with JobMatchers {

Expand Down Expand Up @@ -76,7 +76,8 @@ class FlintJobTest extends SparkFunSuite with JobMatchers {
"select 1",
"20",
currentTime - queryRunTime,
new MockTimeProvider(currentTime))
new MockTimeProvider(currentTime),
CleanerFactory.cleaner(false))
assertEqualDataframe(expected, result)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql.util

import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite

class CleanerFactoryTest extends SparkFunSuite with Matchers {

test("CleanerFactory should return NoOpCleaner when streaming is true") {
val cleaner = CleanerFactory.cleaner(streaming = true)
cleaner shouldBe NoOpCleaner
}

test("CleanerFactory should return ShuffleCleaner when streaming is false") {
val cleaner = CleanerFactory.cleaner(streaming = false)
cleaner shouldBe ShuffleCleaner
}
}

0 comments on commit e7d0874

Please sign in to comment.