Skip to content

Commit

Permalink
Merge pull request #175 from stripe/minmax
Browse files Browse the repository at this point in the history
Rectifier
  • Loading branch information
avi-stripe authored Jul 13, 2018
2 parents 4f8cde7 + 0b57a58 commit 71ea4af
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ private object Gradient {
case ExpOp => gradient.toReal * child
case AbsOp =>
If(child.original, gradient.toReal * child.original / child, Real.zero)
case RectifierOp =>
If(child.original < 0, Real.zero, gradient.toReal)
}
}

Expand Down
22 changes: 16 additions & 6 deletions rainier-core/src/main/scala/com/stripe/rainier/compute/Real.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@ sealed trait Real {
def exp: Real = RealOps.unary(this, ir.ExpOp)
def log: Real = RealOps.unary(this, ir.LogOp)

//because abs does not have a smooth derivative, try to avoid using it
//because abs a does not have a smooth derivative, try to avoid using it
def abs: Real = RealOps.unary(this, ir.AbsOp)
def rectifier: Real = RealOps.unary(this, ir.RectifierOp)

def >(other: Real): Real = RealOps.isPositive(this - other)
def <(other: Real): Real = RealOps.isNegative(this - other)
def >=(other: Real): Real = Real.one - (this < other)
def <=(other: Real): Real = Real.one - (this > other)
def >(other: Real): Real = (this - other).rectifier
def <(other: Real): Real = (other - this).rectifier
def >=(other: Real): Real = If(this - other, this > other, Real.one)
def <=(other: Real): Real = If(this - other, this < other, Real.one)

lazy val variables: List[Variable] = RealOps.variables(this)
lazy val gradient: List[Real] = Gradient.derive(variables, this)
Expand All @@ -48,9 +49,18 @@ object Real {
def seq[A](as: Seq[A])(implicit toReal: ToReal[A]): Seq[Real] =
as.map(toReal(_))

def sum(seq: Seq[Real]): Real =
def sum(seq: Iterable[Real]): Real =
seq.foldLeft(Real.zero)(_ + _)

def logSumExp(seq: Iterable[Real]): Real = {
val max = seq.reduce(_ max _)
val shifted = seq.map { x =>
x - max
}
val summed = Real.sum(shifted.map(_.exp))
summed.log + max
}

//print out Scala code that is equivalent to what the Compiler
//would produce as JVM bytecode
def trace(real: Real): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ private[compute] object RealOps {
case LogOp =>
throw new ArithmeticException(
"Cannot take the log of a negative number")
case AbsOp => Infinity
case AbsOp => Infinity
case RectifierOp => Real.zero
}
case Constant(Real.BigZero) =>
op match {
case ExpOp => Real.one
case LogOp => NegInfinity
case AbsOp => Real.zero
case ExpOp => Real.one
case LogOp => NegInfinity
case AbsOp => Real.zero
case RectifierOp => Real.zero
}
case Constant(value) =>
op match {
Expand All @@ -31,6 +33,11 @@ private[compute] object RealOps {
else
Real(Math.log(value.toDouble))
case AbsOp => Real(value.abs)
case RectifierOp =>
if (value.toDouble < 0)
Real.zero
else
original
}
case nc: NonConstant =>
val opt = (op, nc) match {
Expand Down Expand Up @@ -100,10 +107,10 @@ private[compute] object RealOps {
}

def min(left: Real, right: Real): Real =
((left - right).abs - (right + left)) / 2.0
If(left < right, left, right)

def max(left: Real, right: Real): Real =
((left - right).abs + (right + left)) / 2.0
If(left > right, left, right)

def pow(original: Real, exponent: Real): Real =
exponent match {
Expand Down Expand Up @@ -146,15 +153,6 @@ private[compute] object RealOps {
else
BigDecimal(Math.pow(a.toDouble, b.toDouble))

def isPositive(real: Real): Real =
If(real, nonZeroIsPositive(real), Real.zero)

def isNegative(real: Real): Real =
If(real, Real.one - nonZeroIsPositive(real), Real.zero)

private def nonZeroIsPositive(real: Real): Real =
((real.abs / real) + 1) / 2

def variables(real: Real): List[Variable] = {
var seen = Set.empty[Real]
var vars = List.empty[Variable]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,14 +224,11 @@ case class Mixture(components: Map[Continuous, Real]) extends Continuous {

def logDensity(real: Real): Real =
Real
.sum(
components.map {
case (dist, weight) => {
dist.logDensity(real).exp * weight
}
}.toSeq
)
.log
.logSumExp(components.map {
case (dist, weight) => {
dist.logDensity(real) + weight.log
}
})

def param: RandomVariable[Real] = {
val x = new Variable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,14 @@ object UnboundedSupport extends Support {
* A support representing a bounded (min, max) interval.
*/
case class BoundedSupport(min: Real, max: Real) extends Support {
private def logistic(v: Variable): Real =
(Real.one / (Real.one + (v * -1).exp))

def transform(v: Variable): Real =
(Real.one / (Real.one + (v * -1).exp)) * (max - min) + min
logistic(v) * (max - min) + min

def logJacobian(v: Variable): Real =
transform(v).log + (1 - transform(v)).log + (max - min).log
logistic(v).log + (1 - logistic(v)).log + (max - min).log

def isDefinedAt(real: Real): Real = (real > min) * (real < max)
}
Expand All @@ -97,7 +100,7 @@ case class BoundedAboveSupport(max: Real = Real.zero) extends Support {
def transform(v: Variable): Real =
max - (-1 * v).exp

def logJacobian(v: Variable): Real = v
def logJacobian(v: Variable): Real = v * -1

def isDefinedAt(real: Real): Real = (real < max)
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@ import com.stripe.rainier.internal.asm.Opcodes._
import com.stripe.rainier.internal.asm.tree.MethodNode
import com.stripe.rainier.internal.asm.Label

object MathOps {
def rectifier(x: Double): Double =
if (x < 0.0)
0.0
else
x
}

private trait MethodGenerator {
def access = {
if (isStatic)
Expand Down Expand Up @@ -76,13 +84,14 @@ private trait MethodGenerator {
}

def unaryOp(op: UnaryOp): Unit = {
val methodName = op match {
case LogOp => "log"
case ExpOp => "exp"
case AbsOp => "abs"
val (className, methodName) = op match {
case LogOp => ("java/lang/Math", "log")
case ExpOp => ("java/lang/Math", "exp")
case AbsOp => ("java/lang/Math", "abs")
case RectifierOp => ("com/stripe/rainier/ir/MathOps", "rectifier")
}
methodNode.visitMethodInsn(INVOKESTATIC,
"java/lang/Math",
className,
methodName,
"(D)D",
false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ sealed trait UnaryOp
case object ExpOp extends UnaryOp
case object LogOp extends UnaryOp
case object AbsOp extends UnaryOp
case object RectifierOp extends UnaryOp
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,9 @@ object Tracer {

private def name(u: UnaryOp): String =
u match {
case LogOp => "Math.log"
case ExpOp => "Math.exp"
case AbsOp => "Math.abs"
case LogOp => "Math.log"
case ExpOp => "Math.exp"
case AbsOp => "Math.abs"
case RectifierOp => "MathOps.rectifier"
}
}

0 comments on commit 71ea4af

Please sign in to comment.