Skip to content

Commit

Permalink
Address review comments & convert variable to getter
Browse files Browse the repository at this point in the history
Signed-off-by: Partho Sarthi <[email protected]>
  • Loading branch information
parthosa committed Nov 8, 2024
1 parent 59d8220 commit 2baefe2
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import scala.collection.Map

import org.apache.spark.resource.{ExecutorResourceRequest, TaskResourceRequest}
import org.apache.spark.sql.rapids.tool.store.AccumMetaRef
import org.apache.spark.sql.rapids.tool.util.StringUtils
import org.apache.spark.sql.rapids.tool.util.{SparkRuntime, StringUtils}

/**
* This is a warehouse to store all Classes
Expand Down Expand Up @@ -292,7 +292,7 @@ case class UnsupportedOpsProfileResult(appIndex: Int,
case class AppInfoProfileResults(appIndex: Int, appName: String,
appId: Option[String], sparkUser: String,
startTime: Long, endTime: Option[Long], duration: Option[Long],
durationStr: String, sparkRuntime: String, sparkVersion: String,
durationStr: String, sparkRuntime: SparkRuntime.SparkRuntime, sparkVersion: String,
pluginEnabled: Boolean) extends ProfileResult {
override val outputHeaders = Seq("appIndex", "appName", "appId",
"sparkUser", "startTime", "endTime", "duration", "durationStr",
Expand All @@ -315,14 +315,14 @@ case class AppInfoProfileResults(appIndex: Int, appName: String,
override def convertToSeq: Seq[String] = {
Seq(appIndex.toString, appName, appId.getOrElse(""),
sparkUser, startTime.toString, endTimeToStr, durToStr,
durationStr, sparkRuntime, sparkVersion, pluginEnabled.toString)
durationStr, sparkRuntime.toString, sparkVersion, pluginEnabled.toString)
}
override def convertToCSVSeq: Seq[String] = {
Seq(appIndex.toString, StringUtils.reformatCSVString(appName),
StringUtils.reformatCSVString(appId.getOrElse("")), StringUtils.reformatCSVString(sparkUser),
startTime.toString, endTimeToStr, durToStr, StringUtils.reformatCSVString(durationStr),
StringUtils.reformatCSVString(sparkRuntime), StringUtils.reformatCSVString(sparkVersion),
pluginEnabled.toString)
StringUtils.reformatCSVString(sparkRuntime.toString),
StringUtils.reformatCSVString(sparkVersion), pluginEnabled.toString)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trait AppInformationViewTrait extends ViewableTrait[AppInfoProfileResults] {
app.appMetaData.map { a =>
AppInfoProfileResults(index, a.appName, a.appId,
a.sparkUser, a.startTime, a.endTime, app.getAppDuration,
a.getDurationString, app.sparkRuntime.toString, app.sparkVersion, app.gpuMode)
a.getDurationString, app.getSparkRuntime, app.sparkVersion, app.gpuMode)
}.toSeq
}
override def sortView(rows: Seq[AppInfoProfileResults]): Seq[AppInfoProfileResults] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import org.apache.spark.scheduler.{SparkListenerEvent, StageInfo}
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.SparkPlanGraphNode
import org.apache.spark.sql.rapids.tool.store.{AccumManager, DataSourceRecord, SQLPlanModelManager, StageModel, StageModelManager, TaskModelManager}
import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil, SparkRuntime, ToolsPlanGraph, UTF8Source}
import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil, ToolsPlanGraph, UTF8Source}
import org.apache.spark.util.Utils

abstract class AppBase(
Expand Down Expand Up @@ -475,7 +475,6 @@ abstract class AppBase(
protected def postCompletion(): Unit = {
registerAttemptId()
calculateAppDuration()
setSparkRuntime()
}

/**
Expand All @@ -486,19 +485,6 @@ abstract class AppBase(
processEventsInternal()
postCompletion()
}

/**
* Sets the spark runtime based on the properties of the application.
*/
private def setSparkRuntime(): Unit = {
sparkRuntime = if (isPhoton) {
SparkRuntime.PHOTON
} else if (gpuMode) {
SparkRuntime.SPARK_RAPIDS
} else {
SparkRuntime.SPARK
}
}
}

object AppBase {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ trait ClusterTagPropHandler extends CacheablePropsHandler {
var clusterTagClusterId: String = ""
var clusterTagClusterName: String = ""

// A flag to indicate whether the eventlog being processed is an eventlog from Photon.
var isPhoton = false

// Flag used to indicate that the App was a Databricks App.
def isDatabricks: Boolean = {
clusterTags.nonEmpty && clusterTagClusterId.nonEmpty && clusterTagClusterName.nonEmpty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,42 @@ import org.apache.spark.util.Utils.REDACTION_REPLACEMENT_TEXT


/**
 * SparkRuntime enumeration is used to identify the specific runtime environment
 * in which the application is being executed.
 */
object SparkRuntime extends Enumeration {
  type SparkRuntime = Value

  /**
   * Represents the default Apache Spark runtime environment.
   */
  val SPARK: SparkRuntime = Value

  /**
   * Represents the Spark RAPIDS runtime environment.
   */
  val SPARK_RAPIDS: SparkRuntime = Value

  /**
   * Represents the Photon runtime environment on Databricks.
   */
  val PHOTON: SparkRuntime = Value

  /**
   * Returns the SparkRuntime value based on the given parameters.
   * Photon takes precedence: when both flags are set, PHOTON is returned.
   * @param isPhoton Boolean flag indicating whether the application is running on Photon.
   * @param isGpu Boolean flag indicating whether the application is running on GPU.
   * @return the [[SparkRuntime]] value matching the given flags.
   */
  def getRuntime(isPhoton: Boolean, isGpu: Boolean): SparkRuntime.SparkRuntime = {
    if (isPhoton) {
      PHOTON
    } else if (isGpu) {
      SPARK_RAPIDS
    } else {
      SPARK
    }
  }
}

// Handles updating and caching Spark Properties for a Spark application.
Expand Down Expand Up @@ -77,9 +108,10 @@ trait CacheablePropsHandler {

// caches the spark-version from the eventlogs
var sparkVersion: String = ""
// caches the spark runtime based on the application properties
var sparkRuntime: SparkRuntime.Value = SparkRuntime.SPARK
// A flag to indicate whether the eventlog is an eventlog with Spark RAPIDS runtime.
var gpuMode = false
// A flag to indicate whether the eventlog is an eventlog from Photon runtime.
var isPhoton = false
// A flag whether hive is enabled or not. Note that we assume that the
// property is global to the entire application once it is set. a.k.a, it cannot be disabled
// once it was set to true.
Expand Down Expand Up @@ -143,4 +175,12 @@ trait CacheablePropsHandler {
def isGPUModeEnabledForJob(event: SparkListenerJobStart): Boolean = {
gpuMode || ProfileUtils.isPluginEnabled(event.properties.asScala)
}

/**
 * The SparkRuntime environment in which this application executed.
 * Derived on demand from the cached eventlog properties (the Photon flag
 * and the GPU mode flag) rather than stored as mutable state.
 */
def getSparkRuntime: SparkRuntime.SparkRuntime =
  SparkRuntime.getRuntime(isPhoton = isPhoton, isGpu = gpuMode)
}
Original file line number Diff line number Diff line change
Expand Up @@ -1119,14 +1119,14 @@ class ApplicationInfoSuite extends FunSuite with Logging {
val sparkRuntimeTestCases: Seq[(SparkRuntime.Value, String)] = Seq(
SparkRuntime.SPARK -> s"$qualLogDir/nds_q86_test",
SparkRuntime.SPARK_RAPIDS -> s"$logDir/nds_q66_gpu.zstd",
SparkRuntime.PHOTON-> s"$qualLogDir/nds_q88_photon_db_13_3.zstd"
SparkRuntime.PHOTON -> s"$qualLogDir/nds_q88_photon_db_13_3.zstd"
)

sparkRuntimeTestCases.foreach { case (expectedSparkRuntime, eventLog) =>
test(s"test spark runtime property for ${expectedSparkRuntime.toString} eventlog") {
val apps = ToolTestUtils.processProfileApps(Array(eventLog), sparkSession)
assert(apps.size == 1)
assert(apps.head.sparkRuntime == expectedSparkRuntime)
assert(apps.head.getSparkRuntime == expectedSparkRuntime)
}
}
}

0 comments on commit 2baefe2

Please sign in to comment.