From a40d656baa0509731a9f2c749bfd7947eb1fb6ea Mon Sep 17 00:00:00 2001 From: avi-stripe Date: Sat, 9 Feb 2019 10:12:37 -0800 Subject: [PATCH] cleanups and logging for hmc (#323) --- .../com/stripe/rainier/sampler/DualAvg.scala | 50 ++++++++----------- .../com/stripe/rainier/sampler/Ehmc.scala | 44 +++++++++++++--- .../com/stripe/rainier/sampler/HMC.scala | 23 ++++++--- .../com/stripe/rainier/sampler/LeapFrog.scala | 22 +++----- 4 files changed, 83 insertions(+), 56 deletions(-) diff --git a/rainier-core/src/main/scala/com/stripe/rainier/sampler/DualAvg.scala b/rainier-core/src/main/scala/com/stripe/rainier/sampler/DualAvg.scala index 72fef3185..57570297c 100644 --- a/rainier-core/src/main/scala/com/stripe/rainier/sampler/DualAvg.scala +++ b/rainier-core/src/main/scala/com/stripe/rainier/sampler/DualAvg.scala @@ -7,7 +7,7 @@ final private class DualAvg( delta: Double, var logStepSize: Double, var logStepSizeBar: Double, - var avgAcceptanceProb: Double, + var avgError: Double, var iteration: Int, shrinkageTarget: Double, stepSizeUpdateDenom: Double = 0.05, @@ -20,27 +20,22 @@ final private class DualAvg( def update(logAcceptanceProb: Double): Unit = { val newAcceptanceProb = Math.exp(logAcceptanceProb) iteration = iteration + 1 - val avgAcceptanceProbMultiplier = + val avgErrorMultiplier = 1.0 / (iteration.toDouble + acceptanceProbUpdateDenom) val stepSizeMultiplier = Math.pow(iteration.toDouble, -decayRate) - avgAcceptanceProb = ( - (1.0 - avgAcceptanceProbMultiplier) * avgAcceptanceProb - + (avgAcceptanceProbMultiplier * (delta - newAcceptanceProb)) + avgError = ( + (1.0 - avgErrorMultiplier) * avgError + + (avgErrorMultiplier * (delta - newAcceptanceProb)) ) logStepSize = ( shrinkageTarget - - (avgAcceptanceProb * Math.sqrt(iteration.toDouble) / stepSizeUpdateDenom) + - (avgError * Math.sqrt(iteration.toDouble) / stepSizeUpdateDenom) ) logStepSizeBar = (stepSizeMultiplier * logStepSize + (1.0 - stepSizeMultiplier) * logStepSizeBar) - - FINEST.log("warmup iteration %d, avgAcceptanceProb %f, logStepSize %f", - iteration, - avgAcceptanceProb, - logStepSize) } } @@ -50,38 +45,34 @@ private object DualAvg { delta = delta, logStepSize = Math.log(stepSize), logStepSizeBar = 0.0, - avgAcceptanceProb = 0.0, + avgError = 0.0, iteration = 0, shrinkageTarget = Math.log(10 * stepSize) ) - def findStepSize(lf: LeapFrog, - params: Array[Double], - delta: Double, - nSteps: Int, - iterations: Int)(implicit rng: RNG): Double = { - FINE.log("Finding reasonable initial step size") - val stepSize0 = findReasonableStepSize(lf, params) - FINE.log("Found initial step size of %f", stepSize0) - + def findStepSize(delta: Double, stepSize0: Double, iterations: Int)( + fn: Double => Double): Double = { if (stepSize0 == 0.0) 0.0 else { val dualAvg = DualAvg(delta, stepSize0) var i = 0 while (i < iterations) { - val logAcceptanceProb = lf.step(params, nSteps, dualAvg.stepSize) + val logAcceptanceProb = fn(dualAvg.stepSize) dualAvg.update(logAcceptanceProb) + i += 1 FINER .atMostEvery(1, SECONDS) - .log("Warmup iteration %d of %d, stepSize %f, acceptance prob %f", - i, - iterations, - dualAvg.stepSize, - Math.exp(logAcceptanceProb)) + .log( + "iteration %d of %d, stepSize %f, acceptance %f, error %f", + i, + iterations, + dualAvg.stepSize, + Math.exp(logAcceptanceProb), + dualAvg.avgError + ) - i += 1 } dualAvg.finalStepSize } @@ -94,8 +85,7 @@ private object DualAvg { exponent: Double): Boolean = exponent * logAcceptanceProb > -exponent * Math.log(2) - private def findReasonableStepSize(lf: LeapFrog, - params: Array[Double]): Double = { + def findReasonableStepSize(lf: LeapFrog, params: Array[Double]): Double = { var stepSize = 1.0 var logAcceptanceProb = lf.tryStepping(params, stepSize) val exponent = computeExponent(logAcceptanceProb) diff --git a/rainier-core/src/main/scala/com/stripe/rainier/sampler/Ehmc.scala b/rainier-core/src/main/scala/com/stripe/rainier/sampler/Ehmc.scala index 8bdaf2e1e..d2873eeff 100644 --- a/rainier-core/src/main/scala/com/stripe/rainier/sampler/Ehmc.scala +++ b/rainier-core/src/main/scala/com/stripe/rainier/sampler/Ehmc.scala @@ -1,6 +1,8 @@ package com.stripe.rainier.sampler import scala.collection.mutable.ListBuffer +import Log._ +import java.util.concurrent.TimeUnit._ /** * Empirical HMC - automated tuning of step size and number of leapfrog steps @@ -8,7 +10,8 @@ import scala.collection.mutable.ListBuffer * @param k the number of iterations to use when determining the * empirical distribution of the total number of leapfrog steps until a u-turn */ -final case class EHMC(l0: Int, k: Int = 2000) extends Sampler { +final case class EHMC(l0: Int, k: Int) extends Sampler { + def sample(density: DensityFunction, warmupIterations: Int, iterations: Int, @@ -16,11 +19,24 @@ final case class EHMC(l0: Int, k: Int = 2000) extends Sampler { val lf = LeapFrog(density) val params = lf.initialize + + FINE.log("Finding reasonable initial step size") + val stepSize0 = DualAvg.findReasonableStepSize(lf, params) + FINE.log("Found initial step size of %f", stepSize0) + + FINE.log("Warming up for %d iterations", warmupIterations) val stepSize = - DualAvg.findStepSize(lf, params, 0.65, l0, warmupIterations) + DualAvg.findStepSize(0.65, stepSize0, warmupIterations) { ss => + lf.step(params, 1, ss) + } + FINE.log("Found step size of %f", stepSize) - val empiricalL: Vector[Int] = - lf.empiricalLongestSteps(params, l0, k, stepSize) + FINE.log("Sampling %d path lengths", k) + val empiricalL = Vector.fill(k) { + lf.longestBatchStep(params, l0, stepSize)._2 + } + val sorted = empiricalL.toList.sorted + FINE.log("Using a range of %d to %d steps", sorted.head, sorted.last) if (stepSize == 0.0) List(lf.variables(params)) @@ -28,17 +44,33 @@ final case class EHMC(l0: Int, k: Int = 2000) extends Sampler { val buf = new ListBuffer[Array[Double]] var i = 0 - while (i < iterations) { + FINE.log("Sampling for %d iterations", iterations) + var acceptSum = 0.0 + while (i < iterations) { val j = rng.int(k) val nSteps = empiricalL(j) - lf.step(params, nSteps, stepSize) + FINER + .atMostEvery(1, SECONDS) + .log("Sampling iteration %d of %d for %d steps, acceptance rate %f", + i, + iterations, + nSteps, + (acceptSum / i)) + + val logAccept = lf.step(params, nSteps, stepSize) + acceptSum += Math.exp(logAccept) if (i % keepEvery == 0) buf += lf.variables(params) i += 1 } + FINE.log("Finished sampling, acceptance rate %f", (acceptSum / i)) buf.toList } } } + +object EHMC { + def apply(l0: Int): EHMC = EHMC(l0, 100) +} diff --git a/rainier-core/src/main/scala/com/stripe/rainier/sampler/HMC.scala b/rainier-core/src/main/scala/com/stripe/rainier/sampler/HMC.scala index 1be073113..4dc9c25f3 100644 --- a/rainier-core/src/main/scala/com/stripe/rainier/sampler/HMC.scala +++ b/rainier-core/src/main/scala/com/stripe/rainier/sampler/HMC.scala @@ -12,9 +12,15 @@ final case class HMC(nSteps: Int) extends Sampler { val lf = LeapFrog(density) val params = lf.initialize + FINE.log("Finding reasonable initial step size") + val stepSize0 = DualAvg.findReasonableStepSize(lf, params) + FINE.log("Found initial step size of %f", stepSize0) + FINE.log("Finding step size using %d warmup iterations", warmupIterations) val stepSize = - DualAvg.findStepSize(lf, params, 0.65, nSteps, warmupIterations) + DualAvg.findStepSize(0.65, stepSize0, warmupIterations) { ss => + lf.step(params, nSteps, ss) + } FINE.log("Found step size of %f", stepSize) if (stepSize == 0.0) { @@ -25,16 +31,21 @@ final case class HMC(nSteps: Int) extends Sampler { var i = 0 FINE.log("Sampling for %d iterations", iterations) + var acceptSum = 0.0 while (i < iterations) { - FINER - .atMostEvery(1, SECONDS) - .log("Sampling iteration %d of %d", i, iterations) - lf.step(params, nSteps, stepSize) + val logAccept = lf.step(params, nSteps, stepSize) + acceptSum += Math.exp(logAccept) if (i % keepEvery == 0) buf += lf.variables(params) i += 1 + FINER + .atMostEvery(1, SECONDS) + .log("Sampling iteration %d of %d, acceptance rate %f", + i, + iterations, + (acceptSum / i)) } - FINE.log("Finished sampling") + FINE.log("Finished sampling, acceptance rate %f", (acceptSum / i)) buf.toList } diff --git a/rainier-core/src/main/scala/com/stripe/rainier/sampler/LeapFrog.scala b/rainier-core/src/main/scala/com/stripe/rainier/sampler/LeapFrog.scala index c2b9614c0..4d099653f 100644 --- a/rainier-core/src/main/scala/com/stripe/rainier/sampler/LeapFrog.scala +++ b/rainier-core/src/main/scala/com/stripe/rainier/sampler/LeapFrog.scala @@ -1,6 +1,7 @@ package com.stripe.rainier.sampler import Log._ +import java.util.concurrent.TimeUnit._ private[sampler] case class LeapFrog(density: DensityFunction) { /* @@ -118,6 +119,10 @@ private[sampler] case class LeapFrog(density: DensityFunction) { } else { copy(isUturnBuf, pqBuf) } + + FINER + .atMostEvery(1, SECONDS) + .log("%d steps until u-turn", l) l } @@ -127,9 +132,8 @@ private[sampler] case class LeapFrog(density: DensityFunction) { * @param l0 the initial number of steps * @param stepSize the current value of the leapfrog step size */ - private def longestBatchStep(params: Array[Double], - l0: Int, - stepSize: Double)(implicit rng: RNG): Int = { + def longestBatchStep(params: Array[Double], l0: Int, stepSize: Double)( + implicit rng: RNG): (Double, Int) = { initializePs(params) copy(params, pqBuf) @@ -138,19 +142,9 @@ private[sampler] case class LeapFrog(density: DensityFunction) { val a = logAcceptanceProb(params, pqBuf) if (math.log(u) < a) copy(pqBuf, params) - l + (a, l) } - /** - * Calculate a vector representing the empirical distribution - * of the steps taken until a u-turn - */ - def empiricalLongestSteps(params: Array[Double], - l0: Int, - k: Int, - stepSize: Double)(implicit rng: RNG): Vector[Int] = - Vector.fill(k)(longestBatchStep(params, l0, stepSize)) - private def copy(sourceArray: Array[Double], targetArray: Array[Double]): Unit = System.arraycopy(sourceArray, 0, targetArray, 0, inputOutputSize)