Clean shuffle data #312

Merged (3 commits) on Apr 22, 2024

@@ -6,24 +6,15 @@
// defined in spark package so that I can use ThreadUtils
package org.apache.spark.sql

import java.util.Locale
import java.util.concurrent.atomic.AtomicInteger

import org.opensearch.client.{RequestOptions, RestHighLevelClient}
import org.opensearch.cluster.metadata.MappingMetadata
import org.opensearch.common.settings.Settings
import org.opensearch.common.xcontent.XContentType
import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions}
import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge
import play.api.libs.json._

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.types.{StructField, _}
import org.apache.spark.sql.types._

/**
* Spark SQL Application entrypoint
@@ -7,26 +7,18 @@ package org.apache.spark.sql

import java.util.Locale

import scala.concurrent.{ExecutionContext, Future, TimeoutException}
import scala.concurrent.duration.{Duration, MINUTES}

import com.amazonaws.services.s3.model.AmazonS3Exception
import org.apache.commons.text.StringEscapeUtils.unescapeJava
import org.opensearch.flint.core.{FlintClient, IRestHighLevelClient}
import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.core.IRestHighLevelClient
import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue}
import play.api.libs.json._

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintREPL.envinromentProvider
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructField, StructType}
import org.apache.spark.sql.util.{DefaultThreadPoolFactory, EnvironmentProvider, RealEnvironment, RealTimeProvider, ThreadPoolFactory, TimeProvider}
import org.apache.spark.util.ThreadUtils
import org.apache.spark.sql.types._
import org.apache.spark.sql.util._

trait FlintJobExecutor {
this: Logging =>
@@ -157,7 +149,8 @@ trait FlintJobExecutor {
query: String,
sessionId: String,
startTime: Long,
timeProvider: TimeProvider): DataFrame = {
timeProvider: TimeProvider,
cleaner: Cleaner): DataFrame = {
// Create the schema dataframe
val schemaRows = result.schema.fields.map { field =>
Row(field.name, field.dataType.typeName)
@@ -192,6 +185,11 @@
val resultSchemaToSave = resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'"))
val endTime = timeProvider.currentEpochMillis()

// https://github.com/opensearch-project/opensearch-spark/issues/302. Clean shuffle data
// after the query result has been consumed. Streaming query shuffle data is cleaned after
// each microBatch execution.
cleaner.cleanUp(spark)

// Create the data rows
val rows = Seq(
(
@@ -375,7 +373,8 @@
query: String,
dataSource: String,
queryId: String,
sessionId: String): DataFrame = {
sessionId: String,
streaming: Boolean): DataFrame = {
// Execute SQL query
val startTime = System.currentTimeMillis()
// we have to set job group in the same thread that started the query according to spark doc
@@ -390,7 +389,8 @@
query,
sessionId,
startTime,
currentTimeProvider)
currentTimeProvider,
CleanerFactory.cleaner(streaming))
}

private def handleQueryException(
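Putting the pieces of this change together, here is a minimal hypothetical sketch (not part of the PR; CleanerFlowSketch and runAndClean are made-up names) of how the new streaming flag selects a Cleaner and when cleanUp is invoked once a batch result has been consumed:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.util.{Cleaner, CleanerFactory}

// Hypothetical illustration only: mirrors how executeQuery/getFormattedData use the cleaner.
object CleanerFlowSketch {
  def runAndClean(spark: SparkSession, query: String, streaming: Boolean): Unit = {
    // ShuffleCleaner for interactive/batch queries; NoOpCleaner for streaming queries,
    // whose shuffle data is instead cleaned by the listener after each microBatch.
    val cleaner: Cleaner = CleanerFactory.cleaner(streaming)

    val result = spark.sql(query)
    result.collect()       // consume the query result first

    cleaner.cleanUp(spark) // then release shuffle data, as getFormattedData now does
  }
}
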
@@ -22,15 +22,13 @@ import org.opensearch.flint.app.{FlintCommand, FlintInstance}
import org.opensearch.flint.app.FlintInstance.formats
import org.opensearch.flint.core.FlintOptions
import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.{decrementCounter, getTimerContext, incrementCounter, registerGauge, stopTimer}
import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer}
import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}
import org.opensearch.search.sort.SortOrder

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintJob.createSparkSession
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.flint.config.FlintSparkConf.REPL_INACTIVITY_TIMEOUT_MILLIS
import org.apache.spark.sql.util.{DefaultShutdownHookManager, ShutdownHookManagerTrait}
import org.apache.spark.util.ThreadUtils

@@ -829,7 +827,13 @@ object FlintREPL extends Logging with FlintJobExecutor {
startTime)
} else {
val futureQueryExecution = Future {
executeQuery(spark, flintCommand.query, dataSource, flintCommand.queryId, sessionId)
executeQuery(
spark,
flintCommand.query,
dataSource,
flintCommand.queryId,
sessionId,
false)
}(executionContext)

// time out after 10 minutes
@@ -14,13 +14,10 @@ import scala.util.{Failure, Success, Try}

import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
import org.opensearch.flint.core.storage.OpenSearchUpdater

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintJob.createSparkSession
import org.apache.spark.sql.FlintREPL.{executeQuery, logInfo, threadPoolFactory, updateFlintInstanceBeforeShutdown}
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.util.ShuffleCleaner
import org.apache.spark.util.ThreadUtils

case class JobOperator(
@@ -53,7 +50,7 @@
val futureMappingCheck = Future {
checkAndCreateIndex(osClient, resultIndex)
}
val data = executeQuery(spark, query, dataSource, "", "")
val data = executeQuery(spark, query, dataSource, "", "", streaming)

val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES))
dataToWrite = Some(mappingCheckResult match {
@@ -92,6 +89,8 @@
try {
// Wait for the streaming job to complete if there is no error and a streaming job is running
if (!exceptionThrown && streaming && spark.streams.active.nonEmpty) {
// Clean Spark shuffle data after each microBatch.
spark.streams.addListener(new ShuffleCleaner(spark))
// wait for any child thread to finish before the main thread terminates
spark.streams.awaitAnyTermination()
}
@@ -149,4 +148,5 @@
case false => incrementCounter(MetricConstants.STREAMING_SUCCESS_METRIC)
}
}

}
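For the streaming path above, a self-contained sketch of the same wiring (hypothetical object and method names; the listener registration and awaitAnyTermination call mirror what JobOperator does in this diff):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.util.ShuffleCleaner

// Hypothetical standalone example: register the cleaner listener before blocking on
// the streaming query, so shuffle data is released after every completed microBatch.
object StreamingCleanupSketch {
  def awaitWithCleanup(spark: SparkSession): Unit = {
    if (spark.streams.active.nonEmpty) {
      // onQueryProgress fires after each microBatch and calls ShuffleCleaner.cleanUp(spark).
      spark.streams.addListener(new ShuffleCleaner(spark))
      spark.streams.awaitAnyTermination()
    }
  }
}
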
@@ -0,0 +1,57 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql.util

import org.apache.spark.{MapOutputTrackerMaster, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryListener

/**
* Clean Spark shuffle data after each microBatch.
* https://github.com/opensearch-project/opensearch-spark/issues/302
*/
class ShuffleCleaner(spark: SparkSession) extends StreamingQueryListener with Logging {

override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {}

override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
ShuffleCleaner.cleanUp(spark)
}

override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {}
}

trait Cleaner {
def cleanUp(spark: SparkSession)
}

object CleanerFactory {
def cleaner(streaming: Boolean): Cleaner = {
if (streaming) NoOpCleaner else ShuffleCleaner
}
}

/**
* No operation cleaner.
*/
object NoOpCleaner extends Cleaner {
override def cleanUp(spark: SparkSession): Unit = {}
}

/**
* Spark shuffle data cleaner.
*/
object ShuffleCleaner extends Cleaner with Logging {
def cleanUp(spark: SparkSession): Unit = {
logInfo("Before cleanUp Shuffle")
val cleaner = spark.sparkContext.cleaner
val masterTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
val shuffleIds = masterTracker.shuffleStatuses.keys.toSet
shuffleIds.foreach(shuffleId => cleaner.foreach(c => c.doCleanupShuffle(shuffleId, true)))
logInfo("After cleanUp Shuffle")
}
}
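ShuffleCleaner.cleanUp enumerates the shuffle IDs registered with the driver's MapOutputTrackerMaster and asks the ContextCleaner, when one is present, to remove each of them (the second argument to doCleanupShuffle is the blocking flag, so removal is synchronous). A small local demo, assuming a local-mode SparkSession and made-up object names, of triggering the same cleanup after a batch query:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.util.ShuffleCleaner

// Hypothetical local demo, not part of the PR.
object ShuffleCleanerDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("shuffle-cleaner-demo")
      .getOrCreate()
    import spark.implicits._

    // groupBy forces a shuffle, so a shuffle status is registered on the driver.
    val counts = Seq("a", "b", "a", "c").toDF("k").groupBy("k").count()
    counts.collect()

    // Drop the registered shuffle data, as the PR does once a batch result is consumed.
    ShuffleCleaner.cleanUp(spark)
    spark.stop()
  }
}
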
@@ -7,7 +7,7 @@ package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.MockTimeProvider
import org.apache.spark.sql.util.{CleanerFactory, MockTimeProvider}

class FlintJobTest extends SparkFunSuite with JobMatchers {

@@ -76,7 +76,8 @@ class FlintJobTest extends SparkFunSuite with JobMatchers {
"select 1",
"20",
currentTime - queryRunTime,
new MockTimeProvider(currentTime))
new MockTimeProvider(currentTime),
CleanerFactory.cleaner(false))
assertEqualDataframe(expected, result)
}

@@ -0,0 +1,23 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql.util

import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite

class CleanerFactoryTest extends SparkFunSuite with Matchers {

test("CleanerFactory should return NoOpCleaner when streaming is true") {
val cleaner = CleanerFactory.cleaner(streaming = true)
cleaner shouldBe NoOpCleaner
}

test("CleanerFactory should return ShuffleCleaner when streaming is false") {
val cleaner = CleanerFactory.cleaner(streaming = false)
cleaner shouldBe ShuffleCleaner
}
}
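A further check that could be added here (hypothetical, not in this PR): NoOpCleaner.cleanUp takes no action, so it can be invoked even without a live SparkSession.

package org.apache.spark.sql.util

import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite

class NoOpCleanerTest extends SparkFunSuite with Matchers {

  test("NoOpCleaner.cleanUp should do nothing") {
    // NoOpCleaner's body is empty, so even a null session is accepted here.
    noException should be thrownBy NoOpCleaner.cleanUp(null)
  }
}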