Skip to content

Commit

Permalink
Fix ToMap encoder (#324)
Browse files Browse the repository at this point in the history
* wip

* RandomVariable.fit and fix Map encoder

* remove println
  • Loading branch information
avi-stripe authored Feb 9, 2019
1 parent a40d656 commit 98cb369
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ object Encoder {

def create(acc: List[Variable]): (Map[String, Real], List[Variable]) =
toMap.fields.foldRight((Map[String, Real](), acc)) {
case (field, (map, _)) =>
case (field, (map, a)) =>
val v = new Variable
(map + (field -> v), v :: acc)
(map + (field -> v), v :: a)
}

def extract(t: T, acc: List[Double]): List[Double] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,14 @@ class RandomVariable[+T](val value: T, val targets: Set[Target]) {
def gradient(index: Int) = outputs(index + 1)
}

def densityAtOrigin: Double = {
val inputs = new Array[Double](dataFn.numInputs)
val globals = new Array[Double](dataFn.numGlobals)
val outputs = new Array[Double](dataFn.numOutputs)
dataFn(inputs, globals, outputs)
outputs(0)
}

lazy val densityValue: Real = targetGroup.base

//this is really just here to allow destructuring in for{}
Expand Down Expand Up @@ -196,4 +204,8 @@ object RandomVariable {

def fill[A](k: Int)(fn: => RandomVariable[A]): RandomVariable[Seq[A]] =
traverse(List.fill(k)(fn))

def fit[L, T](l: L, seq: Seq[T])(
implicit toLH: ToLikelihood[L, T]): RandomVariable[Unit] =
toLH(l).fit(seq)
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ final case class EHMC(l0: Int, k: Int) extends Sampler {
.log("Sampling iteration %d of %d for %d steps, acceptance rate %f",
i,
iterations,
nSteps,
nSteps,
(acceptSum / i))

val logAccept = lf.step(params, nSteps, stepSize)
Expand Down

0 comments on commit 98cb369

Please sign in to comment.