From 7c71461f5de570b081c8fa813986f47238aa337d Mon Sep 17 00:00:00 2001 From: Arthur Bit-Monnot Date: Thu, 28 Feb 2019 17:40:55 +0100 Subject: [PATCH] Minor fixes related to handling unusual double values (+/- infinity and NaN). (#333) --- .../stripe/rainier/compute/Evaluator.scala | 12 +- .../com/stripe/rainier/compute/RealOps.scala | 2 + .../com/stripe/rainier/compute/ToReal.scala | 2 + .../com/stripe/rainier/compute/RealTest.scala | 107 +++++++++++------- 4 files changed, 76 insertions(+), 47 deletions(-) diff --git a/rainier-core/src/main/scala/com/stripe/rainier/compute/Evaluator.scala b/rainier-core/src/main/scala/com/stripe/rainier/compute/Evaluator.scala index 9f514f2ce..b26d920ed 100644 --- a/rainier-core/src/main/scala/com/stripe/rainier/compute/Evaluator.scala +++ b/rainier-core/src/main/scala/com/stripe/rainier/compute/Evaluator.scala @@ -4,6 +4,8 @@ class Evaluator(var cache: Map[Real, Double]) extends Numeric[Real] { def toDouble(x: Real): Double = x match { case Constant(v) => v.toDouble + case Infinity => Double.PositiveInfinity + case NegInfinity => Double.NegativeInfinity case _ => cache.get(x) match { case Some(v) => v @@ -22,14 +24,16 @@ class Evaluator(var cache: Map[Real, Double]) extends Numeric[Real] { case l: Line => l.ax.toList.map { case (r, d) => toDouble(r) * d.toDouble }.sum + l.b.toDouble case l: LogLine => - l.ax.toList - .map { case (r, d) => Math.pow(toDouble(r), d.toDouble) } - .reduce(_ * _) + l.ax.toList.map { case (r, d) => Math.pow(toDouble(r), d.toDouble) }.product case Unary(original, op) => - eval(RealOps.unary(Constant(toDouble(original)), op)) + // must use Real(_) constructor since Constant(_) constructor would result in undesirable errors + // at infinities and unhelpful NumberFormatException on NaN, all due to conversion to BigDecimal + val ev = Real(toDouble(original)) + eval(RealOps.unary(ev, op)) case Compare(left, right) => eval(RealOps.compare(toDouble(left), toDouble(right))) case Pow(base, exponent) => + // note: result can be NaN when base < 0 && exponent is negative and not an int Math.pow(toDouble(base), toDouble(exponent)) case l: Lookup => toDouble(l.table(toDouble(l.index).toInt - l.low)) diff --git a/rainier-core/src/main/scala/com/stripe/rainier/compute/RealOps.scala b/rainier-core/src/main/scala/com/stripe/rainier/compute/RealOps.scala index 302cb8b92..ea1171049 100644 --- a/rainier-core/src/main/scala/com/stripe/rainier/compute/RealOps.scala +++ b/rainier-core/src/main/scala/com/stripe/rainier/compute/RealOps.scala @@ -190,6 +190,8 @@ private[compute] object RealOps { def pow(a: BigDecimal, b: BigDecimal): BigDecimal = if (b.isValidInt) a.pow(b.toInt) + else if (a < Real.BigZero) + throw new ArithmeticException(s"Undefined: $a ^ $b") else BigDecimal(Math.pow(a.toDouble, b.toDouble)) diff --git a/rainier-core/src/main/scala/com/stripe/rainier/compute/ToReal.scala b/rainier-core/src/main/scala/com/stripe/rainier/compute/ToReal.scala index 93e83e551..79641a618 100644 --- a/rainier-core/src/main/scala/com/stripe/rainier/compute/ToReal.scala +++ b/rainier-core/src/main/scala/com/stripe/rainier/compute/ToReal.scala @@ -18,6 +18,8 @@ trait LowPriToReal { Real.negInfinity else if (double.isInfinity) Real.infinity + else if (double.isNaN) + throw new ArithmeticException("Trying to convert NaN to Real") else Constant(BigDecimal(double)) } diff --git a/rainier-tests/src/test/scala/com/stripe/rainier/compute/RealTest.scala b/rainier-tests/src/test/scala/com/stripe/rainier/compute/RealTest.scala index 55b7f4e72..79e8906be 100644 --- a/rainier-tests/src/test/scala/com/stripe/rainier/compute/RealTest.scala +++ b/rainier-tests/src/test/scala/com/stripe/rainier/compute/RealTest.scala @@ -3,6 +3,8 @@ package com.stripe.rainier.compute import org.scalatest._ import com.stripe.rainier.core._ import scala.util.{Try, Success, Failure} +import Double.{PositiveInfinity => Inf, NegativeInfinity => NegInf, NaN} + class RealTest extends FunSuite { def run(description: String, defined: Double => Boolean = _ => true, @@ -13,43 +15,45 @@ class RealTest extends FunSuite { val result = fn(x) val deriv = result.gradient.head - def evalAt(d: Double): Double = Try { fn(Constant(d)) } match { - case Success(Infinity) => 1.0 / 0.0 - case Success(NegInfinity) => -1.0 / 0.0 - case Success(Constant(bd)) => bd.toDouble - case Failure(_: ArithmeticException) => 0.0 / 0.0 - case Failure(_: NumberFormatException) => 0.0 / 0.0 - case Failure(e) => throw e - case x => sys.error("Non-constant value " + x) + def evalAt(d: Double): Double = Try { fn(Real(d)) } match { + case Success(Infinity) => Inf + case Success(NegInfinity) => NegInf + case Success(Constant(bd)) => bd.toDouble + case Failure(_: ArithmeticException) => NaN + case Failure(e) => throw e + case x => sys.error("Non-constant value " + x) } val c = Compiler(200, 100).compile(List(x), result) val dc = Compiler(200, 100).compile(List(x), deriv) - List(1.0, 0.0, -1.0, 2.0, -2.0, 0.5, -0.5).filter(defined).foreach { n => - val constant = evalAt(n) - if (reference != null) { - assertWithinEpsilon(constant, reference(n), s"[c/ref, n=$n]") - } - val eval = new Evaluator(Map(x -> n)) - val withVar = eval.toDouble(result) - assertWithinEpsilon(constant, withVar, s"[c/ev, n=$n]") - val compiled = c(Array(n)) - assertWithinEpsilon(withVar, compiled, s"[ev/ir, n=$n]") - - // derivatives of automated differentiation vs numeric differentiation - if (derivable(n)) { - val dx = 10E-6 - val numDiff = (evalAt(n + dx) - evalAt(n - dx)) / (dx * 2) - val diffWithVar = eval.toDouble(deriv) - assertWithinEpsilon(numDiff, - diffWithVar, - s"[numDiff/diffWithVar, n=$n]") - val diffCompiled = dc(Array(n)) - assertWithinEpsilon(diffWithVar, - diffCompiled, - s"[diffWithVar/diffCompiled, n=$n]") + List(1.0, 0.0, -1.0, 2.0, -2.0, 0.5, -0.5, NegInf, Inf) + .filter(defined) + .foreach { n => + val constant = evalAt(n) + if (reference != null) { + assertWithinEpsilon(constant, reference(n), s"[c/ref, n=$n]") + } + val eval = new Evaluator(Map(x -> n)) + val withVar = eval.toDouble(result) + assertWithinEpsilon(constant, withVar, s"[c/ev, n=$n]") + val compiled = c(Array(n)) + assertWithinEpsilon(withVar, compiled, s"[ev/ir, n=$n]") + + // derivatives of automated differentiation vs numeric differentiation + // exclude infinite values for which numeric differentiation does not make sense + if (derivable(n) && !n.isInfinite) { + val dx = 10E-6 + val numDiff = (evalAt(n + dx) - evalAt(n - dx)) / (dx * 2) + val diffWithVar = eval.toDouble(deriv) + assertWithinEpsilon(numDiff, + diffWithVar, + s"[numDiff/diffWithVar, n=$n]") + val diffCompiled = dc(Array(n)) + assertWithinEpsilon(diffWithVar, + diffCompiled, + s"[diffWithVar/diffCompiled, n=$n]") + } } - } } } @@ -76,13 +80,13 @@ class RealTest extends FunSuite { run("log") { x => x.abs.log } - run("sin", reference = math.sin) { x => + run("sin", defined = !_.isInfinite, reference = math.sin) { x => x.sin } - run("cos", reference = math.cos) { x => + run("cos", defined = !_.isInfinite, reference = math.cos) { x => x.cos } - run("tan", reference = math.tan) { x => + run("tan", defined = !_.isInfinite, reference = math.tan) { x => x.tan } run("asin", defined = x => x > -1 && x < 1, reference = math.asin) { x => @@ -100,10 +104,16 @@ class RealTest extends FunSuite { run("cosh", reference = math.cosh) { x => x.cosh } - run("tanh", reference = math.tanh) { x => + run("tanh", defined = !_.isInfinite, reference = math.tanh) { x => x.tanh } - run("cos(x^2)") { x => + run("tanh at infty") { x => + // tanh does have limits at infinities but the current explicit representation does not handle it (results in inf/inf = NaN) + // this additional test simply ensure that results are consistent among all implementations, + // without checking against the more informed reference implementation in the JDK + x.tanh + } + run("cos(x^2)", defined = !_.isInfinite) { x => (x * x).cos } run("temp") { x => @@ -129,7 +139,15 @@ class RealTest extends FunSuite { Real.gt(x, 0, x * x, x + 1) } - run("normal") { x => + run("normal", defined = !_.isPosInfinity) { x => + // FIXME: letting this be defined at +infinity results in the following error: + // -Infinity did not equal NaN [c/ev, n=Infinity] (RealTest.scala:67) + // I.e. two different results from constant folding and evaluation. + // same applies to the "normal sum" test case below + Normal(x, 1).logDensity(Real(1d)) + } + + run("normal sum", defined = !_.isPosInfinity) { x => Real.sum(Range.BigDecimal(0d, 2d, 1d).toList.map { y => Normal(x, 1).logDensity(Real(y)) }) @@ -140,11 +158,11 @@ class RealTest extends FunSuite { (logistic * (Real.one - logistic)).log } - run("minimal logistic") { x => + run("minimal logistic", reference = x => 1d / (math.exp(x) + 1)) { x => Real.one / (x.exp + 1) } - run("log x^2", derivable = _ != 0) { x => + run("log x^2", derivable = _ != 0, reference = x => math.log(x * x)) { x => x.pow(2).log } @@ -154,13 +172,16 @@ class RealTest extends FunSuite { }) } - run("4x^3") { x => + run("4x^3", reference = x => 4 * x * x * x) { x => (((((x + x) * x) + (x * x)) * x) + (x * x * x)) } - run("lookup", derivable = _ => false) { x => // not derivable + run("lookup", + defined = x => x.abs <= 2 && (x.abs * 2).isValidInt, + derivable = _ => false, + reference = _.abs * 2) { x => val i = x.abs * 2 //should be a non-negative whole number Lookup(i, Real.seq(List(0, 1, 2, 3, 4))) } @@ -173,7 +194,7 @@ class RealTest extends FunSuite { } } - run("pow", derivable = _ >= 0) { x => + run("pow", defined = _ >= 0) { x => x.pow(x) }