-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
311d071
commit a40d656
Showing
4 changed files
with
83 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
44 changes: 38 additions & 6 deletions
44
rainier-core/src/main/scala/com/stripe/rainier/sampler/Ehmc.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,76 @@ | ||
package com.stripe.rainier.sampler | ||
|
||
import scala.collection.mutable.ListBuffer | ||
import Log._ | ||
import java.util.concurrent.TimeUnit._ | ||
|
||
/** | ||
* Empirical HMC - automated tuning of step size and number of leapfrog steps | ||
* @param l0 the initial number of leapfrog steps to use during the dual averaging phase | ||
* @param k the number of iterations to use when determining the | ||
* empirical distribution of the total number of leapfrog steps until a u-turn | ||
*/ | ||
final case class EHMC(l0: Int, k: Int = 2000) extends Sampler { | ||
final case class EHMC(l0: Int, k: Int) extends Sampler { | ||
|
||
def sample(density: DensityFunction, | ||
warmupIterations: Int, | ||
iterations: Int, | ||
keepEvery: Int)(implicit rng: RNG): List[Array[Double]] = { | ||
|
||
val lf = LeapFrog(density) | ||
val params = lf.initialize | ||
|
||
FINE.log("Finding reasonable initial step size") | ||
val stepSize0 = DualAvg.findReasonableStepSize(lf, params) | ||
FINE.log("Found initial step size of %f", stepSize0) | ||
|
||
FINE.log("Warming up for %d iterations", warmupIterations) | ||
val stepSize = | ||
DualAvg.findStepSize(lf, params, 0.65, l0, warmupIterations) | ||
DualAvg.findStepSize(0.65, stepSize0, warmupIterations) { ss => | ||
lf.step(params, 1, ss) | ||
} | ||
FINE.log("Found step size of %f", stepSize) | ||
|
||
val empiricalL: Vector[Int] = | ||
lf.empiricalLongestSteps(params, l0, k, stepSize) | ||
FINE.log("Sampling %d path lengths", k) | ||
val empiricalL = Vector.fill(k) { | ||
lf.longestBatchStep(params, l0, stepSize)._2 | ||
} | ||
val sorted = empiricalL.toList.sorted | ||
FINE.log("Using a range of %d to %d steps", sorted.head, sorted.last) | ||
|
||
if (stepSize == 0.0) | ||
List(lf.variables(params)) | ||
else { | ||
val buf = new ListBuffer[Array[Double]] | ||
var i = 0 | ||
|
||
while (i < iterations) { | ||
FINE.log("Sampling for %d iterations", iterations) | ||
|
||
var acceptSum = 0.0 | ||
while (i < iterations) { | ||
val j = rng.int(k) | ||
val nSteps = empiricalL(j) | ||
|
||
lf.step(params, nSteps, stepSize) | ||
FINER | ||
.atMostEvery(1, SECONDS) | ||
.log("Sampling iteration %d of %d for %d steps, acceptance rate %f", | ||
i, | ||
iterations, | ||
nSteps, | ||
(acceptSum / i)) | ||
|
||
val logAccept = lf.step(params, nSteps, stepSize) | ||
acceptSum += Math.exp(logAccept) | ||
if (i % keepEvery == 0) | ||
buf += lf.variables(params) | ||
i += 1 | ||
} | ||
FINE.log("Finished sampling, acceptance rate %f", (acceptSum / i)) | ||
buf.toList | ||
} | ||
} | ||
} | ||
|
||
object EHMC { | ||
def apply(l0: Int): EHMC = EHMC(l0, 100) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters