Skip to content

Commit

Permalink
cleanups and logging for hmc (#323)
Browse files Browse the repository at this point in the history
  • Loading branch information
avi-stripe authored Feb 9, 2019
1 parent 311d071 commit a40d656
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ final private class DualAvg(
delta: Double,
var logStepSize: Double,
var logStepSizeBar: Double,
var avgAcceptanceProb: Double,
var avgError: Double,
var iteration: Int,
shrinkageTarget: Double,
stepSizeUpdateDenom: Double = 0.05,
Expand All @@ -20,27 +20,22 @@ final private class DualAvg(
def update(logAcceptanceProb: Double): Unit = {
val newAcceptanceProb = Math.exp(logAcceptanceProb)
iteration = iteration + 1
val avgAcceptanceProbMultiplier =
val avgErrorMultiplier =
1.0 / (iteration.toDouble + acceptanceProbUpdateDenom)
val stepSizeMultiplier = Math.pow(iteration.toDouble, -decayRate)

avgAcceptanceProb = (
(1.0 - avgAcceptanceProbMultiplier) * avgAcceptanceProb
+ (avgAcceptanceProbMultiplier * (delta - newAcceptanceProb))
avgError = (
(1.0 - avgErrorMultiplier) * avgError
+ (avgErrorMultiplier * (delta - newAcceptanceProb))
)

logStepSize = (
shrinkageTarget
- (avgAcceptanceProb * Math.sqrt(iteration.toDouble) / stepSizeUpdateDenom)
- (avgError * Math.sqrt(iteration.toDouble) / stepSizeUpdateDenom)
)

logStepSizeBar = (stepSizeMultiplier * logStepSize
+ (1.0 - stepSizeMultiplier) * logStepSizeBar)

FINEST.log("warmup iteration %d, avgAcceptanceProb %f, logStepSize %f",
iteration,
avgAcceptanceProb,
logStepSize)
}
}

Expand All @@ -50,38 +45,34 @@ private object DualAvg {
delta = delta,
logStepSize = Math.log(stepSize),
logStepSizeBar = 0.0,
avgAcceptanceProb = 0.0,
avgError = 0.0,
iteration = 0,
shrinkageTarget = Math.log(10 * stepSize)
)

def findStepSize(lf: LeapFrog,
params: Array[Double],
delta: Double,
nSteps: Int,
iterations: Int)(implicit rng: RNG): Double = {
FINE.log("Finding reasonable initial step size")
val stepSize0 = findReasonableStepSize(lf, params)
FINE.log("Found initial step size of %f", stepSize0)

def findStepSize(delta: Double, stepSize0: Double, iterations: Int)(
fn: Double => Double): Double = {
if (stepSize0 == 0.0)
0.0
else {
val dualAvg = DualAvg(delta, stepSize0)
var i = 0
while (i < iterations) {
val logAcceptanceProb = lf.step(params, nSteps, dualAvg.stepSize)
val logAcceptanceProb = fn(dualAvg.stepSize)
dualAvg.update(logAcceptanceProb)
i += 1

FINER
.atMostEvery(1, SECONDS)
.log("Warmup iteration %d of %d, stepSize %f, acceptance prob %f",
i,
iterations,
dualAvg.stepSize,
Math.exp(logAcceptanceProb))
.log(
"iteration %d of %d, stepSize %f, acceptance %f, error %f",
i,
iterations,
dualAvg.stepSize,
Math.exp(logAcceptanceProb),
dualAvg.avgError
)

i += 1
}
dualAvg.finalStepSize
}
Expand All @@ -94,8 +85,7 @@ private object DualAvg {
exponent: Double): Boolean =
exponent * logAcceptanceProb > -exponent * Math.log(2)

private def findReasonableStepSize(lf: LeapFrog,
params: Array[Double]): Double = {
def findReasonableStepSize(lf: LeapFrog, params: Array[Double]): Double = {
var stepSize = 1.0
var logAcceptanceProb = lf.tryStepping(params, stepSize)
val exponent = computeExponent(logAcceptanceProb)
Expand Down
44 changes: 38 additions & 6 deletions rainier-core/src/main/scala/com/stripe/rainier/sampler/Ehmc.scala
Original file line number Diff line number Diff line change
@@ -1,44 +1,76 @@
package com.stripe.rainier.sampler

import scala.collection.mutable.ListBuffer
import Log._
import java.util.concurrent.TimeUnit._

/**
* Empirical HMC - automated tuning of step size and number of leapfrog steps
* @param l0 the initial number of leapfrog steps to use during the dual averaging phase
* @param k the number of iterations to use when determining the
* empirical distribution of the total number of leapfrog steps until a u-turn
*/
final case class EHMC(l0: Int, k: Int = 2000) extends Sampler {
final case class EHMC(l0: Int, k: Int) extends Sampler {

def sample(density: DensityFunction,
warmupIterations: Int,
iterations: Int,
keepEvery: Int)(implicit rng: RNG): List[Array[Double]] = {

val lf = LeapFrog(density)
val params = lf.initialize

FINE.log("Finding reasonable initial step size")
val stepSize0 = DualAvg.findReasonableStepSize(lf, params)
FINE.log("Found initial step size of %f", stepSize0)

FINE.log("Warming up for %d iterations", warmupIterations)
val stepSize =
DualAvg.findStepSize(lf, params, 0.65, l0, warmupIterations)
DualAvg.findStepSize(0.65, stepSize0, warmupIterations) { ss =>
lf.step(params, 1, ss)
}
FINE.log("Found step size of %f", stepSize)

val empiricalL: Vector[Int] =
lf.empiricalLongestSteps(params, l0, k, stepSize)
FINE.log("Sampling %d path lengths", k)
val empiricalL = Vector.fill(k) {
lf.longestBatchStep(params, l0, stepSize)._2
}
val sorted = empiricalL.toList.sorted
FINE.log("Using a range of %d to %d steps", sorted.head, sorted.last)

if (stepSize == 0.0)
List(lf.variables(params))
else {
val buf = new ListBuffer[Array[Double]]
var i = 0

while (i < iterations) {
FINE.log("Sampling for %d iterations", iterations)

var acceptSum = 0.0
while (i < iterations) {
val j = rng.int(k)
val nSteps = empiricalL(j)

lf.step(params, nSteps, stepSize)
FINER
.atMostEvery(1, SECONDS)
.log("Sampling iteration %d of %d for %d steps, acceptance rate %f",
i,
iterations,
nSteps,
(acceptSum / i))

val logAccept = lf.step(params, nSteps, stepSize)
acceptSum += Math.exp(logAccept)
if (i % keepEvery == 0)
buf += lf.variables(params)
i += 1
}
FINE.log("Finished sampling, acceptance rate %f", (acceptSum / i))
buf.toList
}
}
}

object EHMC {
def apply(l0: Int): EHMC = EHMC(l0, 100)
}
23 changes: 17 additions & 6 deletions rainier-core/src/main/scala/com/stripe/rainier/sampler/HMC.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@ final case class HMC(nSteps: Int) extends Sampler {
val lf = LeapFrog(density)
val params = lf.initialize

FINE.log("Finding reasonable initial step size")
val stepSize0 = DualAvg.findReasonableStepSize(lf, params)
FINE.log("Found initial step size of %f", stepSize0)

FINE.log("Finding step size using %d warmup iterations", warmupIterations)
val stepSize =
DualAvg.findStepSize(lf, params, 0.65, nSteps, warmupIterations)
DualAvg.findStepSize(0.65, stepSize0, warmupIterations) { ss =>
lf.step(params, nSteps, ss)
}
FINE.log("Found step size of %f", stepSize)

if (stepSize == 0.0) {
Expand All @@ -25,16 +31,21 @@ final case class HMC(nSteps: Int) extends Sampler {
var i = 0
FINE.log("Sampling for %d iterations", iterations)

var acceptSum = 0.0
while (i < iterations) {
FINER
.atMostEvery(1, SECONDS)
.log("Sampling iteration %d of %d", i, iterations)
lf.step(params, nSteps, stepSize)
val logAccept = lf.step(params, nSteps, stepSize)
acceptSum += Math.exp(logAccept)
if (i % keepEvery == 0)
buf += lf.variables(params)
i += 1
FINER
.atMostEvery(1, SECONDS)
.log("Sampling iteration %d of %d, acceptance rate %f",
i,
iterations,
(acceptSum / i))
}
FINE.log("Finished sampling")
FINE.log("Finished sampling, acceptance rate %f", (acceptSum / i))

buf.toList
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.stripe.rainier.sampler

import Log._
import java.util.concurrent.TimeUnit._

private[sampler] case class LeapFrog(density: DensityFunction) {
/*
Expand Down Expand Up @@ -118,6 +119,10 @@ private[sampler] case class LeapFrog(density: DensityFunction) {
} else {
copy(isUturnBuf, pqBuf)
}

FINER
.atMostEvery(1, SECONDS)
.log("%d steps until u-turn", l)
l
}

Expand All @@ -127,9 +132,8 @@ private[sampler] case class LeapFrog(density: DensityFunction) {
* @param l0 the initial number of steps
* @param stepSize the current value of the leapfrog step size
*/
private def longestBatchStep(params: Array[Double],
l0: Int,
stepSize: Double)(implicit rng: RNG): Int = {
def longestBatchStep(params: Array[Double], l0: Int, stepSize: Double)(
implicit rng: RNG): (Double, Int) = {

initializePs(params)
copy(params, pqBuf)
Expand All @@ -138,19 +142,9 @@ private[sampler] case class LeapFrog(density: DensityFunction) {
val a = logAcceptanceProb(params, pqBuf)
if (math.log(u) < a)
copy(pqBuf, params)
l
(a, l)
}

/**
* Calculate a vector representing the empirical distribution
* of the steps taken until a u-turn
*/
def empiricalLongestSteps(params: Array[Double],
l0: Int,
k: Int,
stepSize: Double)(implicit rng: RNG): Vector[Int] =
Vector.fill(k)(longestBatchStep(params, l0, stepSize))

private def copy(sourceArray: Array[Double],
targetArray: Array[Double]): Unit =
System.arraycopy(sourceArray, 0, targetArray, 0, inputOutputSize)
Expand Down

0 comments on commit a40d656

Please sign in to comment.