diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileClassWarehouse.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileClassWarehouse.scala index b853dc482..24b6cbd73 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileClassWarehouse.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileClassWarehouse.scala @@ -20,7 +20,7 @@ import scala.collection.Map import org.apache.spark.resource.{ExecutorResourceRequest, TaskResourceRequest} import org.apache.spark.sql.rapids.tool.store.AccumMetaRef -import org.apache.spark.sql.rapids.tool.util.StringUtils +import org.apache.spark.sql.rapids.tool.util.{SparkRuntime, StringUtils} /** * This is a warehouse to store all Classes @@ -292,7 +292,7 @@ case class UnsupportedOpsProfileResult(appIndex: Int, case class AppInfoProfileResults(appIndex: Int, appName: String, appId: Option[String], sparkUser: String, startTime: Long, endTime: Option[Long], duration: Option[Long], - durationStr: String, sparkRuntime: String, sparkVersion: String, + durationStr: String, sparkRuntime: SparkRuntime.SparkRuntime, sparkVersion: String, pluginEnabled: Boolean) extends ProfileResult { override val outputHeaders = Seq("appIndex", "appName", "appId", "sparkUser", "startTime", "endTime", "duration", "durationStr", @@ -315,14 +315,14 @@ case class AppInfoProfileResults(appIndex: Int, appName: String, override def convertToSeq: Seq[String] = { Seq(appIndex.toString, appName, appId.getOrElse(""), sparkUser, startTime.toString, endTimeToStr, durToStr, - durationStr, sparkRuntime, sparkVersion, pluginEnabled.toString) + durationStr, sparkRuntime.toString, sparkVersion, pluginEnabled.toString) } override def convertToCSVSeq: Seq[String] = { Seq(appIndex.toString, StringUtils.reformatCSVString(appName), StringUtils.reformatCSVString(appId.getOrElse("")), StringUtils.reformatCSVString(sparkUser), startTime.toString, endTimeToStr, durToStr, StringUtils.reformatCSVString(durationStr), - StringUtils.reformatCSVString(sparkRuntime), StringUtils.reformatCSVString(sparkVersion), - pluginEnabled.toString) + StringUtils.reformatCSVString(sparkRuntime.toString), + StringUtils.reformatCSVString(sparkVersion), pluginEnabled.toString) } } diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/views/InformationView.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/views/InformationView.scala index 66277551a..7a7484c38 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/views/InformationView.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/views/InformationView.scala @@ -29,7 +29,7 @@ trait AppInformationViewTrait extends ViewableTrait[AppInfoProfileResults] { app.appMetaData.map { a => AppInfoProfileResults(index, a.appName, a.appId, a.sparkUser, a.startTime, a.endTime, app.getAppDuration, - a.getDurationString, app.sparkRuntime.toString, app.sparkVersion, app.gpuMode) + a.getDurationString, app.getSparkRuntime, app.sparkVersion, app.gpuMode) }.toSeq } override def sortView(rows: Seq[AppInfoProfileResults]): Seq[AppInfoProfileResults] = { diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala index da70afd91..e3313b832 100644 --- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala +++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala @@ -37,7 +37,7 @@ import org.apache.spark.scheduler.{SparkListenerEvent, StageInfo} import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraphNode import org.apache.spark.sql.rapids.tool.store.{AccumManager, DataSourceRecord, SQLPlanModelManager, StageModel, StageModelManager, TaskModelManager} -import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil, SparkRuntime, ToolsPlanGraph, UTF8Source} +import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil, ToolsPlanGraph, UTF8Source} import org.apache.spark.util.Utils abstract class AppBase( @@ -475,7 +475,6 @@ abstract class AppBase( protected def postCompletion(): Unit = { registerAttemptId() calculateAppDuration() - setSparkRuntime() } /** @@ -486,19 +485,6 @@ abstract class AppBase( processEventsInternal() postCompletion() } - - /** - * Sets the spark runtime based on the properties of the application. - */ - private def setSparkRuntime(): Unit = { - sparkRuntime = if (isPhoton) { - SparkRuntime.PHOTON - } else if (gpuMode) { - SparkRuntime.SPARK_RAPIDS - } else { - SparkRuntime.SPARK - } - } } object AppBase { diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/ClusterTagPropHandler.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/ClusterTagPropHandler.scala index 4377ef0e5..aebab1082 100644 --- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/ClusterTagPropHandler.scala +++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/ClusterTagPropHandler.scala @@ -28,9 +28,6 @@ trait ClusterTagPropHandler extends CacheablePropsHandler { var clusterTagClusterId: String = "" var clusterTagClusterName: String = "" - // A flag to indicate whether the eventlog being processed is an eventlog from Photon. - var isPhoton = false - // Flag used to indicate that the App was a Databricks App. def isDatabricks: Boolean = { clusterTags.nonEmpty && clusterTagClusterId.nonEmpty && clusterTagClusterName.nonEmpty diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/CacheablePropsHandler.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/CacheablePropsHandler.scala index 9c3f3abf7..50b6763ff 100644 --- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/CacheablePropsHandler.scala +++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/CacheablePropsHandler.scala @@ -27,11 +27,42 @@ import org.apache.spark.util.Utils.REDACTION_REPLACEMENT_TEXT /** - * Enum to represent different spark runtimes. + * SparkRuntime enumeration is used to identify the specific runtime environment + * in which the application is being executed. */ object SparkRuntime extends Enumeration { type SparkRuntime = Value - val SPARK, SPARK_RAPIDS, PHOTON = Value + + /** + * Represents the default Apache Spark runtime environment. + */ + val SPARK: SparkRuntime = Value + + /** + * Represents the Spark RAPIDS runtime environment. + */ + val SPARK_RAPIDS: SparkRuntime = Value + + /** + * Represents the Photon runtime environment on Databricks. + */ + val PHOTON: SparkRuntime = Value + + /** + * Returns the SparkRuntime value based on the given parameters. + * @param isPhoton Boolean flag indicating whether the application is running on Photon. + * @param isGpu Boolean flag indicating whether the application is running on GPU. + * @return + */ + def getRuntime(isPhoton: Boolean, isGpu: Boolean): SparkRuntime.SparkRuntime = { + if (isPhoton) { + PHOTON + } else if (isGpu) { + SPARK_RAPIDS + } else { + SPARK + } + } } // Handles updating and caching Spark Properties for a Spark application. @@ -77,9 +108,10 @@ trait CacheablePropsHandler { // caches the spark-version from the eventlogs var sparkVersion: String = "" - // caches the spark runtime based on the application properties - var sparkRuntime: SparkRuntime.Value = SparkRuntime.SPARK + // A flag to indicate whether the eventlog is an eventlog with Spark RAPIDS runtime. var gpuMode = false + // A flag to indicate whether the eventlog is an eventlog from Photon runtime. + var isPhoton = false // A flag whether hive is enabled or not. Note that we assume that the // property is global to the entire application once it is set. a.k.a, it cannot be disabled // once it was set to true. @@ -143,4 +175,12 @@ trait CacheablePropsHandler { def isGPUModeEnabledForJob(event: SparkListenerJobStart): Boolean = { gpuMode || ProfileUtils.isPluginEnabled(event.properties.asScala) } + + /** + * Returns the SparkRuntime environment in which the application is being executed. + * This is calculated based on other cached properties. + */ + def getSparkRuntime: SparkRuntime.SparkRuntime = { + SparkRuntime.getRuntime(isPhoton, gpuMode) + } } diff --git a/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationInfoSuite.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationInfoSuite.scala index 2f2436fe4..de9921cec 100644 --- a/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationInfoSuite.scala +++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationInfoSuite.scala @@ -1119,14 +1119,14 @@ class ApplicationInfoSuite extends FunSuite with Logging { val sparkRuntimeTestCases: Seq[(SparkRuntime.Value, String)] = Seq( SparkRuntime.SPARK -> s"$qualLogDir/nds_q86_test", SparkRuntime.SPARK_RAPIDS -> s"$logDir/nds_q66_gpu.zstd", - SparkRuntime.PHOTON-> s"$qualLogDir/nds_q88_photon_db_13_3.zstd" + SparkRuntime.PHOTON -> s"$qualLogDir/nds_q88_photon_db_13_3.zstd" ) sparkRuntimeTestCases.foreach { case (expectedSparkRuntime, eventLog) => test(s"test spark runtime property for ${expectedSparkRuntime.toString} eventlog") { val apps = ToolTestUtils.processProfileApps(Array(eventLog), sparkSession) assert(apps.size == 1) - assert(apps.head.sparkRuntime == expectedSparkRuntime) + assert(apps.head.getSparkRuntime == expectedSparkRuntime) } } }