Clean shuffle data #312

Merged
merged 3 commits into from
Apr 22, 2024
Merged
@@ -6,24 +6,15 @@
// defined in spark package so that I can use ThreadUtils
package org.apache.spark.sql

import java.util.Locale
import java.util.concurrent.atomic.AtomicInteger

import org.opensearch.client.{RequestOptions, RestHighLevelClient}
import org.opensearch.cluster.metadata.MappingMetadata
import org.opensearch.common.settings.Settings
import org.opensearch.common.xcontent.XContentType
import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions}
import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge
import play.api.libs.json._

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.types.{StructField, _}
import org.apache.spark.sql.types._

/**
* Spark SQL Application entrypoint
@@ -7,26 +7,18 @@ package org.apache.spark.sql

import java.util.Locale

import scala.concurrent.{ExecutionContext, Future, TimeoutException}
import scala.concurrent.duration.{Duration, MINUTES}

import com.amazonaws.services.s3.model.AmazonS3Exception
import org.apache.commons.text.StringEscapeUtils.unescapeJava
import org.opensearch.flint.core.{FlintClient, IRestHighLevelClient}
import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.core.IRestHighLevelClient
import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue}
import play.api.libs.json._

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintREPL.envinromentProvider
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructField, StructType}
import org.apache.spark.sql.util.{DefaultThreadPoolFactory, EnvironmentProvider, RealEnvironment, RealTimeProvider, ThreadPoolFactory, TimeProvider}
import org.apache.spark.util.ThreadUtils
import org.apache.spark.sql.types._
import org.apache.spark.sql.util._

trait FlintJobExecutor {
this: Logging =>
@@ -157,7 +149,8 @@ trait FlintJobExecutor {
query: String,
sessionId: String,
startTime: Long,
timeProvider: TimeProvider): DataFrame = {
timeProvider: TimeProvider,
cleaner: Cleaner): DataFrame = {
// Create the schema dataframe
val schemaRows = result.schema.fields.map { field =>
Row(field.name, field.dataType.typeName)
@@ -192,6 +185,11 @@ trait FlintJobExecutor {
val resultSchemaToSave = resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'"))
val endTime = timeProvider.currentEpochMillis()

// https://github.com/opensearch-project/opensearch-spark/issues/302: clean shuffle data
// after the query result has been consumed. Streaming query shuffle data is cleaned
// after each microBatch execution.
cleaner.cleanUp(spark)

// Create the data rows
val rows = Seq(
(
@@ -375,7 +373,8 @@ trait FlintJobExecutor {
query: String,
dataSource: String,
queryId: String,
sessionId: String): DataFrame = {
sessionId: String,
streaming: Boolean): DataFrame = {
// Execute SQL query
val startTime = System.currentTimeMillis()
// we have to set job group in the same thread that started the query according to spark doc
@@ -390,7 +389,8 @@ trait FlintJobExecutor {
query,
sessionId,
startTime,
currentTimeProvider)
currentTimeProvider,
CleanerFactory.cleaner(streaming))
}

private def handleQueryException(
@@ -22,15 +22,13 @@ import org.opensearch.flint.app.{FlintCommand, FlintInstance}
import org.opensearch.flint.app.FlintInstance.formats
import org.opensearch.flint.core.FlintOptions
import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.{decrementCounter, getTimerContext, incrementCounter, registerGauge, stopTimer}
import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer}
import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}
import org.opensearch.search.sort.SortOrder

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintJob.createSparkSession
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.flint.config.FlintSparkConf.REPL_INACTIVITY_TIMEOUT_MILLIS
import org.apache.spark.sql.util.{DefaultShutdownHookManager, ShutdownHookManagerTrait}
import org.apache.spark.util.ThreadUtils

@@ -829,7 +827,13 @@ object FlintREPL extends Logging with FlintJobExecutor {
startTime)
} else {
val futureQueryExecution = Future {
executeQuery(spark, flintCommand.query, dataSource, flintCommand.queryId, sessionId)
executeQuery(
spark,
flintCommand.query,
dataSource,
flintCommand.queryId,
sessionId,
false)
}(executionContext)

// time out after 10 minutes
@@ -14,13 +14,10 @@ import scala.util.{Failure, Success, Try}

import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
import org.opensearch.flint.core.storage.OpenSearchUpdater

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintJob.createSparkSession
import org.apache.spark.sql.FlintREPL.{executeQuery, logInfo, threadPoolFactory, updateFlintInstanceBeforeShutdown}
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.util.ShuffleCleaner
import org.apache.spark.util.ThreadUtils

case class JobOperator(
@@ -53,7 +50,7 @@ case class JobOperator(
val futureMappingCheck = Future {
checkAndCreateIndex(osClient, resultIndex)
}
val data = executeQuery(spark, query, dataSource, "", "")
val data = executeQuery(spark, query, dataSource, "", "", streaming)

val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES))
dataToWrite = Some(mappingCheckResult match {
@@ -92,6 +89,8 @@ case class JobOperator(
try {
// Wait for streaming job complete if no error and there is streaming job running
if (!exceptionThrown && streaming && spark.streams.active.nonEmpty) {
// Clean Spark shuffle data after each microBatch.
spark.streams.addListener(new ShuffleCleaner(spark))
// wait for any child thread to finish before the main thread terminates
spark.streams.awaitAnyTermination()
}
@@ -149,4 +148,5 @@ case class JobOperator(
case false => incrementCounter(MetricConstants.STREAMING_SUCCESS_METRIC)
}
}

}
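
For orientation, here is a minimal wiring sketch of the two cleanup paths this change introduces: batch queries clean shuffle data once the result has been consumed, while streaming jobs register ShuffleCleaner as a StreamingQueryListener so cleanup runs after every completed micro-batch. The helper method names below are illustrative only, not part of the PR.

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.util.{CleanerFactory, ShuffleCleaner}

object CleanupWiringSketch {
  def cleanAfterBatchResult(spark: SparkSession, result: DataFrame): Unit = {
    result.collect() // consume the query result before dropping its shuffle data
    // CleanerFactory returns ShuffleCleaner for batch queries ...
    CleanerFactory.cleaner(streaming = false).cleanUp(spark)
  }

  def registerStreamingCleanup(spark: SparkSession): Unit = {
    // ... and NoOpCleaner for streaming queries, because cleanup is instead driven
    // by this listener after each completed micro-batch.
    spark.streams.addListener(new ShuffleCleaner(spark))
  }
}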
@@ -0,0 +1,57 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql.util

import org.apache.spark.{MapOutputTrackerMaster, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryListener

/**
* Clean Spark shuffle data after each microBatch.
* https://github.com/opensearch-project/opensearch-spark/issues/302
*/
class ShuffleCleaner(spark: SparkSession) extends StreamingQueryListener with Logging {

override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {}

override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
ShuffleCleaner.cleanUp(spark)
}

override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {}
}

trait Cleaner {
def cleanUp(spark: SparkSession)
}

object CleanerFactory {
def cleaner(streaming: Boolean): Cleaner = {
if (streaming) NoOpCleaner else ShuffleCleaner
}
}

/**
* No operation cleaner.
*/
object NoOpCleaner extends Cleaner {
override def cleanUp(spark: SparkSession): Unit = {}
}

/**
* Spark shuffle data cleaner.
*/
object ShuffleCleaner extends Cleaner with Logging {
def cleanUp(spark: SparkSession): Unit = {
logInfo("Before cleanUp Shuffle")
val cleaner = spark.sparkContext.cleaner
val masterTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
val shuffleIds = masterTracker.shuffleStatuses.keys.toSet
shuffleIds.foreach(shuffleId => cleaner.foreach(c => c.doCleanupShuffle(shuffleId, true)))
logInfo("After cleanUp Shuffle")
}
}
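
A hypothetical verification sketch (not part of the PR): because ContextCleaner.doCleanupShuffle unregisters each shuffle from the driver's MapOutputTrackerMaster, the tracker should hold no shuffle statuses once cleanUp returns, assuming no concurrent jobs register new shuffles. Like the PR's util classes, such code would need to live under the org.apache.spark package to reach these internals.

package org.apache.spark.sql.util

import org.apache.spark.{MapOutputTrackerMaster, SparkEnv}
import org.apache.spark.sql.SparkSession

// Hypothetical helper, named here only for illustration.
object ShuffleCleanupCheck {
  def assertShuffleDataCleaned(spark: SparkSession): Unit = {
    ShuffleCleaner.cleanUp(spark)
    val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
    // doCleanupShuffle unregisters each shuffle from the driver-side tracker, so
    // nothing should remain here (assuming no concurrent jobs add new shuffles).
    assert(tracker.shuffleStatuses.isEmpty, "expected all shuffle statuses to be unregistered")
  }
}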
@@ -7,7 +7,7 @@ package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.MockTimeProvider
import org.apache.spark.sql.util.{CleanerFactory, MockTimeProvider}

class FlintJobTest extends SparkFunSuite with JobMatchers {

@@ -76,7 +76,8 @@ class FlintJobTest extends SparkFunSuite with JobMatchers {
"select 1",
"20",
currentTime - queryRunTime,
new MockTimeProvider(currentTime))
new MockTimeProvider(currentTime),
CleanerFactory.cleaner(false))
assertEqualDataframe(expected, result)
}

@@ -0,0 +1,23 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql.util

import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite

class CleanerFactoryTest extends SparkFunSuite with Matchers {

test("CleanerFactory should return NoOpCleaner when streaming is true") {
val cleaner = CleanerFactory.cleaner(streaming = true)
cleaner shouldBe NoOpCleaner
}

test("CleanerFactory should return ShuffleCleaner when streaming is false") {
val cleaner = CleanerFactory.cleaner(streaming = false)
cleaner shouldBe ShuffleCleaner
}
}