diff --git a/rainier-core/src/main/scala/com/stripe/rainier/compute/Encoder.scala b/rainier-core/src/main/scala/com/stripe/rainier/compute/Encoder.scala index a83c8402d..c012a89b4 100644 --- a/rainier-core/src/main/scala/com/stripe/rainier/compute/Encoder.scala +++ b/rainier-core/src/main/scala/com/stripe/rainier/compute/Encoder.scala @@ -78,9 +78,9 @@ object Encoder { def create(acc: List[Variable]): (Map[String, Real], List[Variable]) = toMap.fields.foldRight((Map[String, Real](), acc)) { - case (field, (map, _)) => + case (field, (map, a)) => val v = new Variable - (map + (field -> v), v :: acc) + (map + (field -> v), v :: a) } def extract(t: T, acc: List[Double]): List[Double] = { diff --git a/rainier-core/src/main/scala/com/stripe/rainier/core/RandomVariable.scala b/rainier-core/src/main/scala/com/stripe/rainier/core/RandomVariable.scala index a3647ada9..0e323cd79 100644 --- a/rainier-core/src/main/scala/com/stripe/rainier/core/RandomVariable.scala +++ b/rainier-core/src/main/scala/com/stripe/rainier/core/RandomVariable.scala @@ -127,6 +127,14 @@ class RandomVariable[+T](val value: T, val targets: Set[Target]) { def gradient(index: Int) = outputs(index + 1) } + def densityAtOrigin: Double = { + val inputs = new Array[Double](dataFn.numInputs) + val globals = new Array[Double](dataFn.numGlobals) + val outputs = new Array[Double](dataFn.numOutputs) + dataFn(inputs, globals, outputs) + outputs(0) + } + lazy val densityValue: Real = targetGroup.base //this is really just here to allow destructuring in for{} @@ -196,4 +204,8 @@ object RandomVariable { def fill[A](k: Int)(fn: => RandomVariable[A]): RandomVariable[Seq[A]] = traverse(List.fill(k)(fn)) + + def fit[L, T](l: L, seq: Seq[T])( + implicit toLH: ToLikelihood[L, T]): RandomVariable[Unit] = + toLH(l).fit(seq) } diff --git a/rainier-core/src/main/scala/com/stripe/rainier/sampler/Ehmc.scala b/rainier-core/src/main/scala/com/stripe/rainier/sampler/Ehmc.scala index d2873eeff..3a522b43c 100644 --- a/rainier-core/src/main/scala/com/stripe/rainier/sampler/Ehmc.scala +++ b/rainier-core/src/main/scala/com/stripe/rainier/sampler/Ehmc.scala @@ -56,7 +56,7 @@ final case class EHMC(l0: Int, k: Int) extends Sampler { .log("Sampling iteration %d of %d for %d steps, acceptance rate %f", i, iterations, - nSteps, + nSteps, (acceptSum / i)) val logAccept = lf.step(params, nSteps, stepSize)