Allow user to remove broadcast variables when they are no longer used #771

Open · wants to merge 2 commits into base: master
2 changes: 1 addition & 1 deletion core/src/main/scala/spark/SparkContext.scala
@@ -483,7 +483,7 @@ class SparkContext(
* Broadcast a read-only variable to the cluster, returning a [[spark.broadcast.Broadcast]] object for
* reading it in distributed functions. The variable will be sent to each node only once.
*/
def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
def broadcast[T](value: T, tellMaster: Boolean = true) = env.broadcastManager.newBroadcast[T](value, isLocal, tellMaster)

/**
* Add a file to be downloaded with this Spark job on every node.
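For reference, a minimal sketch of how the extended driver API might be called (the local master URL and the names `BroadcastDemo` and `data` are illustrative, not from this patch; `tellMaster` defaults to `true`):

import spark.SparkContext

object BroadcastDemo {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "BroadcastDemo")
    val data = (1 to 1000).toArray

    // Tracked by BlockManagerMaster (the new default), so the block can
    // later be removed on the executors as well.
    val tracked = sc.broadcast(data)

    // Opt out of tracking to keep the previous behaviour.
    val untracked = sc.broadcast(data, tellMaster = false)

    sc.stop()
  }
}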
2 changes: 1 addition & 1 deletion core/src/main/scala/spark/api/java/JavaSparkContext.scala
@@ -308,7 +308,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
* reading it in distributed functions. The variable will be sent to each node only once.
*/
def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value)
def broadcast[T](value: T, tellMaster: Boolean): Broadcast[T] = sc.broadcast(value, tellMaster)

/** Shut down the SparkContext. */
def stop() {
38 changes: 32 additions & 6 deletions core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
@@ -11,7 +11,7 @@ import scala.math
import spark._
import spark.storage.StorageLevel

private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, tellMaster: Boolean)
extends Broadcast[T](id)
with Logging
with Serializable {
@@ -21,7 +21,9 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
def blockId: String = "broadcast_" + id

MultiTracker.synchronized {
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
// If tellMaster is true, let BlockManagerMaster know that we have the broadcast
// block so that it can later notify us to remove it.
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster)
}

@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@@ -58,6 +60,27 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
if (!isLocal) {
sendBroadcast()
}

override def remove(toReleaseSource: Boolean = false) {
logInfo("Remove broadcast variable " + blockId)
if (tellMaster) {
logInfo("remove broadcast variable block" + blockId + " on slaves")
SparkEnv.get.blockManager.master.removeBlock(blockId)
}
SparkEnv.get.blockManager.removeBlock(blockId, false)
if (toReleaseSource) {
releaseSource()
}
}

def releaseSource() {
arrayOfBlocks = null
hasBlocksBitVector = null
numCopiesSent = null
listOfSources = null
serveMR = null
guideMR = null
}

def sendBroadcast() {
logInfo("Local host address: " + hostAddress)
@@ -116,7 +139,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
MultiTracker.synchronized {
SparkEnv.get.blockManager.getSingle(blockId) match {
SparkEnv.get.blockManager.getSingleLocal(blockId) match {
case Some(x) =>
value_ = x.asInstanceOf[T]

@@ -139,8 +162,11 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
val receptionSucceeded = receiveBroadcast(id)
if (receptionSucceeded) {
value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
// Let BlockManagerMaster know that we have the broadcast block so that it can later notify us to remove it.
// If tellMaster is true, let BlockManagerMaster know that we have the broadcast
// block so that it can later notify us to remove it.
SparkEnv.get.blockManager.putSingle(
blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster)
} else {
logError("Reading broadcast variable " + id + " failed")
}
@@ -1033,8 +1059,8 @@ private[spark] class BitTorrentBroadcastFactory
extends BroadcastFactory {
def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }

def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new BitTorrentBroadcast[T](value_, isLocal, id)
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, tellMaster: Boolean) =
new BitTorrentBroadcast[T](value_, isLocal, id, tellMaster)

def stop() { MultiTracker.stop() }
}
12 changes: 8 additions & 4 deletions core/src/main/scala/spark/broadcast/Broadcast.scala
@@ -12,6 +12,10 @@ abstract class Broadcast[T](private[spark] val id: Long) extends Serializable {
// readObject having to be 'private' in sub-classes.

override def toString = "spark.Broadcast(" + id + ")"

// Remove a broadcast block from the SparkContext and the executors that have it.
// Set toReleaseSource to true to also remove the broadcast value from its source.
def remove(toReleaseSource: Boolean)
}

private[spark]
@@ -45,9 +49,9 @@ class BroadcastManager(val _isDriver: Boolean) extends Logging with Serializable
}

private val nextBroadcastId = new AtomicLong(0)

def newBroadcast[T](value_ : T, isLocal: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())

def newBroadcast[T](value_ : T, isLocal: Boolean, tellMaster: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement(), tellMaster)
def isDriver = _isDriver
}
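A short usage sketch of the new removal hook (assuming a live SparkContext `sc`; the `lookup` name and the data are illustrative only):

val lookup = sc.broadcast(Map(1 -> "a", 2 -> "b"))  // tellMaster defaults to true
sc.parallelize(1 to 10, 2).map(i => lookup.value.getOrElse(i, "?")).collect()

// Remove the block from this node and, since the master was told about it,
// from the executors too; passing true also releases the source data
// (in-memory chunks or the HTTP file, depending on the broadcast factory).
lookup.remove(true)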
2 changes: 1 addition & 1 deletion core/src/main/scala/spark/broadcast/BroadcastFactory.scala
@@ -8,6 +8,6 @@ package spark.broadcast
*/
private[spark] trait BroadcastFactory {
def initialize(isDriver: Boolean): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def newBroadcast[T](value: T, isLocal: Boolean, id: Long, tellMaster: Boolean): Broadcast[T]
def stop(): Unit
}
35 changes: 29 additions & 6 deletions core/src/main/scala/spark/broadcast/HttpBroadcast.scala
@@ -13,32 +13,55 @@ import spark._
import spark.storage.StorageLevel
import util.{MetadataCleaner, TimeStampedHashSet}

private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, tellMaster: Boolean = true)
extends Broadcast[T](id) with Logging with Serializable {

def value = value_

def blockId: String = "broadcast_" + id

HttpBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
// If tellMaster is true, let BlockManagerMaster know that we have the broadcast
// block so that it can later notify us to remove it.
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster)
}

if (!isLocal) {
HttpBroadcast.write(id, value_)
}

override def remove(toReleaseSource: Boolean = false) {
logInfo("Remove broadcast variable " + blockId)
if (tellMaster) {
logInfo("remove broadcast variable block" + blockId + " on slaves")
SparkEnv.get.blockManager.master.removeBlock(blockId)
}
SparkEnv.get.blockManager.removeBlock(blockId, false)
if (toReleaseSource) {
releaseSource()
}
}

def releaseSource() {
val path: String = HttpBroadcast.broadcastDir + "/" + "broadcast-" + id
HttpBroadcast.files.internalMap.remove(path)
new File(path).delete()
logInfo("Deleted source broadcast file '" + path + "'")
}

// Called by JVM when deserializing an object
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(blockId) match {
SparkEnv.get.blockManager.getSingleLocal(blockId) match {
case Some(x) => value_ = x.asInstanceOf[T]
case None => {
logInfo("Started reading broadcast variable " + id)
val start = System.nanoTime
value_ = HttpBroadcast.read[T](id)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
// If tellMaster is true, let BlockManagerMaster know that we have the broadcast
// block so that it can later notify us to remove it.
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster)
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
@@ -50,8 +73,8 @@ extends Broadcast[T](id) with Logging with Serializable {
private[spark] class HttpBroadcastFactory extends BroadcastFactory {
def initialize(isDriver: Boolean) { HttpBroadcast.initialize(isDriver) }

def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, tellMaster: Boolean) =
new HttpBroadcast[T](value_, isLocal, id, tellMaster)

def stop() { HttpBroadcast.stop() }
}
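For HttpBroadcast, releasing the source amounts to deleting the broadcast file under HttpBroadcast.broadcastDir and dropping its path from the tracked file set, so the HTTP server can no longer serve the variable to executors that have not yet fetched it.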
35 changes: 29 additions & 6 deletions core/src/main/scala/spark/broadcast/TreeBroadcast.scala
@@ -10,15 +10,17 @@ import scala.math
import spark._
import spark.storage.StorageLevel

private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, tellMaster: Boolean)
extends Broadcast[T](id) with Logging with Serializable {

def value = value_

def blockId = "broadcast_" + id

MultiTracker.synchronized {
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
// If tellMaster is true, let BlockManagerMaster know that we have the broadcast
// block so that it can later notify us to remove it.
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster)
}

@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@@ -46,6 +48,25 @@ extends Broadcast[T](id) with Logging with Serializable {
if (!isLocal) {
sendBroadcast()
}

override def remove(toReleaseSource: Boolean = false) {
logInfo("Remove broadcast variable " + blockId)
if (tellMaster) {
logInfo("remove broadcast variable block" + blockId + " on slaves")
SparkEnv.get.blockManager.master.removeBlock(blockId)
}
SparkEnv.get.blockManager.removeBlock(blockId, false)
if (toReleaseSource) {
releaseSource()
}
}

def releaseSource() {
arrayOfBlocks = null
listOfSources = null
serveMR = null
guideMR = null
}

def sendBroadcast() {
logInfo("Local host address: " + hostAddress)
@@ -92,7 +113,7 @@ extends Broadcast[T](id) with Logging with Serializable {
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
MultiTracker.synchronized {
SparkEnv.get.blockManager.getSingle(blockId) match {
SparkEnv.get.blockManager.getSingleLocal(blockId) match {
case Some(x) =>
value_ = x.asInstanceOf[T]

@@ -114,8 +135,10 @@ extends Broadcast[T](id) with Logging with Serializable {
val receptionSucceeded = receiveBroadcast(id)
if (receptionSucceeded) {
value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
// If tellMaster is true, let BlockManagerMaster know that we have the broadcast
// block so that it can later notify us to remove it.
SparkEnv.get.blockManager.putSingle(
blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster)
} else {
logError("Reading broadcast variable " + id + " failed")
}
@@ -578,8 +601,8 @@ private[spark] class TreeBroadcastFactory
extends BroadcastFactory {
def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }

def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TreeBroadcast[T](value_, isLocal, id)
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, tellMaster: Boolean) =
new TreeBroadcast[T](value_, isLocal, id, tellMaster)

def stop() { MultiTracker.stop() }
}
7 changes: 7 additions & 0 deletions core/src/main/scala/spark/storage/BlockManager.scala
@@ -772,6 +772,13 @@ private[spark] class BlockManager(
def getSingle(blockId: String): Option[Any] = {
get(blockId).map(_.next())
}

/**
* Read a block consisting of a single object from the local BlockManager only.
*/
def getSingleLocal(blockId: String): Option[Any] = {
getLocal(blockId).map(_.next())
}

/**
* Write a block consisting of a single object.
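This helper is what lets the broadcast implementations above switch their readObject lookups from getSingle to getSingleLocal: deserialization consults only the local BlockManager and, on a miss, falls back to the broadcast's own transfer mechanism, presumably to avoid re-fetching a block from a remote BlockManager after the master has asked the slaves to drop it.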
3 changes: 2 additions & 1 deletion examples/src/main/scala/spark/examples/BroadcastTest.scala
@@ -22,10 +22,11 @@ object BroadcastTest {
for (i <- 0 until 2) {
println("Iteration " + i)
println("===========")
val barr1 = sc.broadcast(arr1)
val barr1 = sc.broadcast(arr1, (i == 0))
sc.parallelize(1 to 10, slices).foreach {
i => println(barr1.value.size)
}
barr1.remove(true)
}

System.exit(0)
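As updated, the example registers the broadcast with the master only on the first iteration (the tellMaster argument is i == 0), and calls barr1.remove(true) to drop each broadcast, source included, once the iteration's job has run.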