Skip to content

Commit

Permalink
Minor fixes related to handling unusual double values (+/- infinity a…
Browse files Browse the repository at this point in the history
…nd NaN). (#333)
  • Loading branch information
arbimo authored and avi-stripe committed Feb 28, 2019
1 parent 673408e commit 7c71461
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ class Evaluator(var cache: Map[Real, Double]) extends Numeric[Real] {

def toDouble(x: Real): Double = x match {
case Constant(v) => v.toDouble
case Infinity => Double.PositiveInfinity
case NegInfinity => Double.NegativeInfinity
case _ =>
cache.get(x) match {
case Some(v) => v
Expand All @@ -22,14 +24,16 @@ class Evaluator(var cache: Map[Real, Double]) extends Numeric[Real] {
case l: Line =>
l.ax.toList.map { case (r, d) => toDouble(r) * d.toDouble }.sum + l.b.toDouble
case l: LogLine =>
l.ax.toList
.map { case (r, d) => Math.pow(toDouble(r), d.toDouble) }
.reduce(_ * _)
l.ax.toList.map { case (r, d) => Math.pow(toDouble(r), d.toDouble) }.product
case Unary(original, op) =>
eval(RealOps.unary(Constant(toDouble(original)), op))
// must use Real(_) constructor since Constant(_) constructor would result in undesirable errors
// at infinities and unhelpful NumberFormatException on NaN, all due to conversion to BigDecimal
val ev = Real(toDouble(original))
eval(RealOps.unary(ev, op))
case Compare(left, right) =>
eval(RealOps.compare(toDouble(left), toDouble(right)))
case Pow(base, exponent) =>
// note: result can be NaN when base < 0 && exponent is negative and not an int
Math.pow(toDouble(base), toDouble(exponent))
case l: Lookup =>
toDouble(l.table(toDouble(l.index).toInt - l.low))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ private[compute] object RealOps {
def pow(a: BigDecimal, b: BigDecimal): BigDecimal =
if (b.isValidInt)
a.pow(b.toInt)
else if (a < Real.BigZero)
throw new ArithmeticException(s"Undefined: $a ^ $b")
else
BigDecimal(Math.pow(a.toDouble, b.toDouble))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ trait LowPriToReal {
Real.negInfinity
else if (double.isInfinity)
Real.infinity
else if (double.isNaN)
throw new ArithmeticException("Trying to convert NaN to Real")
else
Constant(BigDecimal(double))
}
Expand Down
107 changes: 64 additions & 43 deletions rainier-tests/src/test/scala/com/stripe/rainier/compute/RealTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package com.stripe.rainier.compute
import org.scalatest._
import com.stripe.rainier.core._
import scala.util.{Try, Success, Failure}
import Double.{PositiveInfinity => Inf, NegativeInfinity => NegInf, NaN}

class RealTest extends FunSuite {
def run(description: String,
defined: Double => Boolean = _ => true,
Expand All @@ -13,43 +15,45 @@ class RealTest extends FunSuite {
val result = fn(x)
val deriv = result.gradient.head

def evalAt(d: Double): Double = Try { fn(Constant(d)) } match {
case Success(Infinity) => 1.0 / 0.0
case Success(NegInfinity) => -1.0 / 0.0
case Success(Constant(bd)) => bd.toDouble
case Failure(_: ArithmeticException) => 0.0 / 0.0
case Failure(_: NumberFormatException) => 0.0 / 0.0
case Failure(e) => throw e
case x => sys.error("Non-constant value " + x)
def evalAt(d: Double): Double = Try { fn(Real(d)) } match {
case Success(Infinity) => Inf
case Success(NegInfinity) => NegInf
case Success(Constant(bd)) => bd.toDouble
case Failure(_: ArithmeticException) => NaN
case Failure(e) => throw e
case x => sys.error("Non-constant value " + x)
}

val c = Compiler(200, 100).compile(List(x), result)
val dc = Compiler(200, 100).compile(List(x), deriv)
List(1.0, 0.0, -1.0, 2.0, -2.0, 0.5, -0.5).filter(defined).foreach { n =>
val constant = evalAt(n)
if (reference != null) {
assertWithinEpsilon(constant, reference(n), s"[c/ref, n=$n]")
}
val eval = new Evaluator(Map(x -> n))
val withVar = eval.toDouble(result)
assertWithinEpsilon(constant, withVar, s"[c/ev, n=$n]")
val compiled = c(Array(n))
assertWithinEpsilon(withVar, compiled, s"[ev/ir, n=$n]")

// derivatives of automated differentiation vs numeric differentiation
if (derivable(n)) {
val dx = 10E-6
val numDiff = (evalAt(n + dx) - evalAt(n - dx)) / (dx * 2)
val diffWithVar = eval.toDouble(deriv)
assertWithinEpsilon(numDiff,
diffWithVar,
s"[numDiff/diffWithVar, n=$n]")
val diffCompiled = dc(Array(n))
assertWithinEpsilon(diffWithVar,
diffCompiled,
s"[diffWithVar/diffCompiled, n=$n]")
List(1.0, 0.0, -1.0, 2.0, -2.0, 0.5, -0.5, NegInf, Inf)
.filter(defined)
.foreach { n =>
val constant = evalAt(n)
if (reference != null) {
assertWithinEpsilon(constant, reference(n), s"[c/ref, n=$n]")
}
val eval = new Evaluator(Map(x -> n))
val withVar = eval.toDouble(result)
assertWithinEpsilon(constant, withVar, s"[c/ev, n=$n]")
val compiled = c(Array(n))
assertWithinEpsilon(withVar, compiled, s"[ev/ir, n=$n]")

// derivatives of automated differentiation vs numeric differentiation
// exclude infinite values for which numeric differentiation does not make sense
if (derivable(n) && !n.isInfinite) {
val dx = 10E-6
val numDiff = (evalAt(n + dx) - evalAt(n - dx)) / (dx * 2)
val diffWithVar = eval.toDouble(deriv)
assertWithinEpsilon(numDiff,
diffWithVar,
s"[numDiff/diffWithVar, n=$n]")
val diffCompiled = dc(Array(n))
assertWithinEpsilon(diffWithVar,
diffCompiled,
s"[diffWithVar/diffCompiled, n=$n]")
}
}
}
}
}

Expand All @@ -76,13 +80,13 @@ class RealTest extends FunSuite {
run("log") { x =>
x.abs.log
}
run("sin", reference = math.sin) { x =>
run("sin", defined = !_.isInfinite, reference = math.sin) { x =>
x.sin
}
run("cos", reference = math.cos) { x =>
run("cos", defined = !_.isInfinite, reference = math.cos) { x =>
x.cos
}
run("tan", reference = math.tan) { x =>
run("tan", defined = !_.isInfinite, reference = math.tan) { x =>
x.tan
}
run("asin", defined = x => x > -1 && x < 1, reference = math.asin) { x =>
Expand All @@ -100,10 +104,16 @@ class RealTest extends FunSuite {
run("cosh", reference = math.cosh) { x =>
x.cosh
}
run("tanh", reference = math.tanh) { x =>
run("tanh", defined = !_.isInfinite, reference = math.tanh) { x =>
x.tanh
}
run("cos(x^2)") { x =>
run("tanh at infty") { x =>
// tanh does have limits at infinities but the current explicit representation does not handle it (results in inf/inf = NaN)
// this additional test simply ensure that results are consistent among all implementations,
// without checking against the more informed reference implementation in the JDK
x.tanh
}
run("cos(x^2)", defined = !_.isInfinite) { x =>
(x * x).cos
}
run("temp") { x =>
Expand All @@ -129,7 +139,15 @@ class RealTest extends FunSuite {
Real.gt(x, 0, x * x, x + 1)
}

run("normal") { x =>
run("normal", defined = !_.isPosInfinity) { x =>
// FIXME: letting this be defined at +infinity results in the following error:
// -Infinity did not equal NaN [c/ev, n=Infinity] (RealTest.scala:67)
// I.e. two different results from constant folding and evaluation.
// same applies to the "normal sum" test case below
Normal(x, 1).logDensity(Real(1d))
}

run("normal sum", defined = !_.isPosInfinity) { x =>
Real.sum(Range.BigDecimal(0d, 2d, 1d).toList.map { y =>
Normal(x, 1).logDensity(Real(y))
})
Expand All @@ -140,11 +158,11 @@ class RealTest extends FunSuite {
(logistic * (Real.one - logistic)).log
}

run("minimal logistic") { x =>
run("minimal logistic", reference = x => 1d / (math.exp(x) + 1)) { x =>
Real.one / (x.exp + 1)
}

run("log x^2", derivable = _ != 0) { x =>
run("log x^2", derivable = _ != 0, reference = x => math.log(x * x)) { x =>
x.pow(2).log
}

Expand All @@ -154,13 +172,16 @@ class RealTest extends FunSuite {
})
}

run("4x^3") { x =>
run("4x^3", reference = x => 4 * x * x * x) { x =>
(((((x + x) * x) +
(x * x)) * x) +
(x * x * x))
}

run("lookup", derivable = _ => false) { x => // not derivable
run("lookup",
defined = x => x.abs <= 2 && (x.abs * 2).isValidInt,
derivable = _ => false,
reference = _.abs * 2) { x =>
val i = x.abs * 2 //should be a non-negative whole number
Lookup(i, Real.seq(List(0, 1, 2, 3, 4)))
}
Expand All @@ -173,7 +194,7 @@ class RealTest extends FunSuite {
}
}

run("pow", derivable = _ >= 0) { x =>
run("pow", defined = _ >= 0) { x =>
x.pow(x)
}

Expand Down

0 comments on commit 7c71461

Please sign in to comment.