diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 29022c7419b4b..0a66cc974da7c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -53,6 +53,8 @@ private[deploy] class Master(
   private val forwardMessageThread =
     ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread")
 
+  private val driverIdPattern = conf.get(DRIVER_ID_PATTERN)
+
   // For application IDs
   private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US)
 
@@ -1175,7 +1177,7 @@ private[deploy] class Master(
   }
 
   private def newDriverId(submitDate: Date): String = {
-    val appId = "driver-%s-%04d".format(createDateFormat.format(submitDate), nextDriverNumber)
+    val appId = driverIdPattern.format(createDateFormat.format(submitDate), nextDriverNumber)
     nextDriverNumber += 1
     appId
   }
diff --git a/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala b/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala
index aaeb37a17249a..bffdc79175bd9 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala
@@ -82,4 +82,12 @@ private[spark] object Deploy {
     .checkValue(_ > 0, "The maximum number of running drivers should be positive.")
     .createWithDefault(Int.MaxValue)
 
+  val DRIVER_ID_PATTERN = ConfigBuilder("spark.deploy.driverIdPattern")
+    .doc("The pattern for driver ID generation based on Java `String.format` method. " +
+      "The default value is `driver-%s-%04d` which represents the existing driver id string, " +
+      "e.g., `driver-20231031224459-0019`. Please be careful to generate unique IDs")
+    .version("4.0.0")
+    .stringConf
+    .checkValue(!_.format("20231101000000", 0).exists(_.isWhitespace), "Whitespace is not allowed.")
+    .createWithDefault("driver-%s-%04d")
 }
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
index fc6c7d267e6a5..cef0e84f20f7a 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
@@ -802,6 +802,7 @@ class MasterSuite extends SparkFunSuite
   private val _waitingDrivers =
     PrivateMethod[mutable.ArrayBuffer[DriverInfo]](Symbol("waitingDrivers"))
   private val _state = PrivateMethod[RecoveryState.Value](Symbol("state"))
+  private val _newDriverId = PrivateMethod[String](Symbol("newDriverId"))
 
   private val workerInfo = makeWorkerInfo(4096, 10)
   private val workerInfos = Array(workerInfo, workerInfo, workerInfo)
@@ -1236,6 +1237,20 @@ class MasterSuite extends SparkFunSuite
   private def getState(master: Master): RecoveryState.Value = {
     master.invokePrivate(_state())
   }
+
+  test("SPARK-45753: Support driver id pattern") {
+    val master = makeMaster(new SparkConf().set(DRIVER_ID_PATTERN, "my-driver-%2$05d"))
+    val submitDate = new Date()
+    assert(master.invokePrivate(_newDriverId(submitDate)) === "my-driver-00000")
+    assert(master.invokePrivate(_newDriverId(submitDate)) === "my-driver-00001")
+  }
+
+  test("SPARK-45753: Prevent invalid driver id patterns") {
+    val m = intercept[IllegalArgumentException] {
+      makeMaster(new SparkConf().set(DRIVER_ID_PATTERN, "my driver"))
+    }.getMessage
+    assert(m.contains("Whitespace is not allowed"))
+  }
 }
 
 private class FakeRecoveryModeFactory(conf: SparkConf, ser: serializer.Serializer)