From 673408e8663af954e41e5250cb1878b1be581c43 Mon Sep 17 00:00:00 2001 From: Arthur Bit-Monnot Date: Tue, 26 Feb 2019 23:21:17 +0100 Subject: [PATCH] Implementation of trigonometric functions (#331) * Implementation of cos and sin * Implementation of tan, asin, acos, atan * Simplifications in derivate expressions * Additional tests for sinh, cosh and tanh --- .../com/stripe/rainier/compute/Gradient.scala | 11 +++- .../com/stripe/rainier/compute/Real.scala | 17 +++++- .../com/stripe/rainier/compute/RealOps.scala | 60 ++++++++++++++++--- .../scala/com/stripe/rainier/ir/IRViz.scala | 14 +++-- .../stripe/rainier/ir/MethodGenerator.scala | 14 +++-- .../scala/com/stripe/rainier/ir/Ops.scala | 8 +++ .../com/stripe/rainier/compute/RealTest.scala | 60 +++++++++++++++---- 7 files changed, 154 insertions(+), 30 deletions(-) diff --git a/rainier-core/src/main/scala/com/stripe/rainier/compute/Gradient.scala b/rainier-core/src/main/scala/com/stripe/rainier/compute/Gradient.scala index 5c0f12d38..5d7052a66 100644 --- a/rainier-core/src/main/scala/com/stripe/rainier/compute/Gradient.scala +++ b/rainier-core/src/main/scala/com/stripe/rainier/compute/Gradient.scala @@ -107,7 +107,16 @@ private object Gradient { Real.zero, Real.zero, gradient.toReal * child.original / child) - case NoOp => gradient.toReal + case NoOp => gradient.toReal + case SinOp => gradient.toReal * child.original.cos + case CosOp => gradient.toReal * (Real.zero - child.original.sin) + case TanOp => gradient.toReal / child.original.cos.pow(2) + case AsinOp => + gradient.toReal / (Real.one - child.original.pow(2)).pow(0.5) + case AcosOp => + -gradient.toReal / (Real.one - child.original.pow(2)).pow(0.5) + case AtanOp => + gradient.toReal / (Real.one + child.original.pow(2)) } } diff --git a/rainier-core/src/main/scala/com/stripe/rainier/compute/Real.scala b/rainier-core/src/main/scala/com/stripe/rainier/compute/Real.scala index 1a031ba4b..2fdf93d8f 100644 --- a/rainier-core/src/main/scala/com/stripe/rainier/compute/Real.scala +++ b/rainier-core/src/main/scala/com/stripe/rainier/compute/Real.scala @@ -19,7 +19,8 @@ sealed trait Real { def +(other: Real): Real = RealOps.add(this, other) def *(other: Real): Real = RealOps.multiply(this, other) - def -(other: Real): Real = this + (other * -1) + def unary_- : Real = this * (-1) + def -(other: Real): Real = this + (-other) def /(other: Real): Real = RealOps.divide(this, other) def min(other: Real): Real = RealOps.min(this, other) @@ -30,6 +31,18 @@ sealed trait Real { def exp: Real = RealOps.unary(this, ir.ExpOp) def log: Real = RealOps.unary(this, ir.LogOp) + def sin: Real = RealOps.unary(this, ir.SinOp) + def cos: Real = RealOps.unary(this, ir.CosOp) + def tan: Real = RealOps.unary(this, ir.TanOp) + + def asin: Real = RealOps.unary(this, ir.AsinOp) + def acos: Real = RealOps.unary(this, ir.AcosOp) + def atan: Real = RealOps.unary(this, ir.AtanOp) + + def sinh: Real = (this.exp - (-this).exp) / 2 + def cosh: Real = (this.exp + (-this).exp) / 2 + def tanh: Real = this.sinh / this.cosh + //because abs does not have a smooth derivative, try to avoid using it def abs: Real = RealOps.unary(this, ir.AbsOp) @@ -89,9 +102,11 @@ object Real { private[compute] val BigZero = BigDecimal(0.0) private[compute] val BigOne = BigDecimal(1.0) private[compute] val BigTwo = BigDecimal(2.0) + private[compute] val BigPi = BigDecimal(math.Pi) val zero: Real = Constant(BigZero) val one: Real = Constant(BigOne) val two: Real = Constant(BigTwo) + val Pi: Real = Constant(BigPi) val infinity: Real = Infinity val negInfinity: Real = NegInfinity } diff --git a/rainier-core/src/main/scala/com/stripe/rainier/compute/RealOps.scala b/rainier-core/src/main/scala/com/stripe/rainier/compute/RealOps.scala index b9a372504..302cb8b92 100644 --- a/rainier-core/src/main/scala/com/stripe/rainier/compute/RealOps.scala +++ b/rainier-core/src/main/scala/com/stripe/rainier/compute/RealOps.scala @@ -6,7 +6,25 @@ private[compute] object RealOps { def unary(original: Real, op: UnaryOp): Real = original match { - case Infinity => Infinity + case Infinity => + op match { + case ExpOp => Infinity + case LogOp => Infinity + case AbsOp => Infinity + case SinOp => + throw new ArithmeticException( + "No limit for 'sin' at positive infinity") + case CosOp => + throw new ArithmeticException( + "No limit for 'cos' at positive infinity") + case TanOp => + throw new ArithmeticException( + "No limit for 'tan' at positive infinity") + case AcosOp => throw new ArithmeticException("acos undefined above 1") + case AsinOp => throw new ArithmeticException("asin undefined above 1") + case AtanOp => Real.Pi / 2 + case NoOp => Infinity + } case NegInfinity => op match { case ExpOp => Real.zero @@ -14,14 +32,34 @@ private[compute] object RealOps { throw new ArithmeticException( "Cannot take the log of a negative number") case AbsOp => Infinity - case NoOp => original + case SinOp => + throw new ArithmeticException( + "No limit for 'sin' at negative infinity") + case CosOp => + throw new ArithmeticException( + "No limit for 'cos' at negative infinity") + case TanOp => + throw new ArithmeticException( + "No limit for 'tan' at negative infinity") + case AcosOp => + throw new ArithmeticException("acos undefined below -1") + case AsinOp => + throw new ArithmeticException("asin undefined below -1") + case AtanOp => -Real.Pi / 2 + case NoOp => original } case Constant(Real.BigZero) => op match { - case ExpOp => Real.one - case LogOp => NegInfinity - case AbsOp => Real.zero - case NoOp => original + case ExpOp => Real.one + case LogOp => NegInfinity + case AbsOp => Real.zero + case SinOp => Real.zero + case CosOp => Real.one + case TanOp => Real.zero + case AsinOp => Real.zero + case AcosOp => Real.Pi / 2 + case AtanOp => Real.zero + case NoOp => original } case Constant(value) => op match { @@ -32,8 +70,14 @@ private[compute] object RealOps { "Cannot take the log of " + value.toDouble) else Real(Math.log(value.toDouble)) - case AbsOp => Real(value.abs) - case NoOp => original + case AbsOp => Real(value.abs) + case SinOp => Real(Math.sin(value.toDouble)) + case CosOp => Real(Math.cos(value.toDouble)) + case TanOp => Real(Math.tan(value.toDouble)) + case AsinOp => Real(Math.asin(value.toDouble)) + case AcosOp => Real(Math.acos(value.toDouble)) + case AtanOp => Real(Math.atan(value.toDouble)) + case NoOp => original } case nc: NonConstant => val opt = (op, nc) match { diff --git a/rainier-core/src/main/scala/com/stripe/rainier/ir/IRViz.scala b/rainier-core/src/main/scala/com/stripe/rainier/ir/IRViz.scala index 967fc331c..900e5002a 100644 --- a/rainier-core/src/main/scala/com/stripe/rainier/ir/IRViz.scala +++ b/rainier-core/src/main/scala/com/stripe/rainier/ir/IRViz.scala @@ -139,10 +139,16 @@ object IRViz { def opLabel(op: UnaryOp): String = op match { - case ExpOp => "exp" - case LogOp => "ln" - case AbsOp => "abs" - case NoOp => "nop" + case ExpOp => "exp" + case LogOp => "ln" + case AbsOp => "abs" + case SinOp => "sin" + case CosOp => "cos" + case TanOp => "tan" + case AsinOp => "asin" + case AcosOp => "acos" + case AtanOp => "atan" + case NoOp => "nop" } def opLabel(op: BinaryOp): String = diff --git a/rainier-core/src/main/scala/com/stripe/rainier/ir/MethodGenerator.scala b/rainier-core/src/main/scala/com/stripe/rainier/ir/MethodGenerator.scala index 7534cf06d..d7c1d39e3 100644 --- a/rainier-core/src/main/scala/com/stripe/rainier/ir/MethodGenerator.scala +++ b/rainier-core/src/main/scala/com/stripe/rainier/ir/MethodGenerator.scala @@ -73,10 +73,16 @@ private trait MethodGenerator { def unaryOp(op: UnaryOp): Unit = { (op match { - case LogOp => Some(("java/lang/Math", "log")) - case ExpOp => Some(("java/lang/Math", "exp")) - case AbsOp => Some(("java/lang/Math", "abs")) - case NoOp => None + case LogOp => Some(("java/lang/Math", "log")) + case ExpOp => Some(("java/lang/Math", "exp")) + case AbsOp => Some(("java/lang/Math", "abs")) + case CosOp => Some(("java/lang/Math", "cos")) + case SinOp => Some(("java/lang/Math", "sin")) + case TanOp => Some(("java/lang/Math", "tan")) + case AsinOp => Some(("java/lang/Math", "asin")) + case AcosOp => Some(("java/lang/Math", "acos")) + case AtanOp => Some(("java/lang/Math", "atan")) + case NoOp => None }).foreach { case (className, methodName) => methodNode.visitMethodInsn(INVOKESTATIC, diff --git a/rainier-core/src/main/scala/com/stripe/rainier/ir/Ops.scala b/rainier-core/src/main/scala/com/stripe/rainier/ir/Ops.scala index 85ec3b8aa..02f9ff435 100644 --- a/rainier-core/src/main/scala/com/stripe/rainier/ir/Ops.scala +++ b/rainier-core/src/main/scala/com/stripe/rainier/ir/Ops.scala @@ -27,3 +27,11 @@ case object ExpOp extends UnaryOp case object LogOp extends UnaryOp case object AbsOp extends UnaryOp case object NoOp extends UnaryOp + +case object SinOp extends UnaryOp +case object CosOp extends UnaryOp +case object TanOp extends UnaryOp + +case object AsinOp extends UnaryOp +case object AcosOp extends UnaryOp +case object AtanOp extends UnaryOp diff --git a/rainier-tests/src/test/scala/com/stripe/rainier/compute/RealTest.scala b/rainier-tests/src/test/scala/com/stripe/rainier/compute/RealTest.scala index 56000736c..55b7f4e72 100644 --- a/rainier-tests/src/test/scala/com/stripe/rainier/compute/RealTest.scala +++ b/rainier-tests/src/test/scala/com/stripe/rainier/compute/RealTest.scala @@ -4,8 +4,10 @@ import org.scalatest._ import com.stripe.rainier.core._ import scala.util.{Try, Success, Failure} class RealTest extends FunSuite { - def run(description: String, testDeriv: Double => Boolean = _ => true)( - fn: Real => Real): Unit = { + def run(description: String, + defined: Double => Boolean = _ => true, + derivable: Double => Boolean = _ => true, + reference: Double => Double = null)(fn: Real => Real): Unit = { test(description) { val x = new Variable val result = fn(x) @@ -17,13 +19,17 @@ class RealTest extends FunSuite { case Success(Constant(bd)) => bd.toDouble case Failure(_: ArithmeticException) => 0.0 / 0.0 case Failure(_: NumberFormatException) => 0.0 / 0.0 + case Failure(e) => throw e case x => sys.error("Non-constant value " + x) } val c = Compiler(200, 100).compile(List(x), result) val dc = Compiler(200, 100).compile(List(x), deriv) - List(1.0, 0.0, -1.0, 2.0, -2.0, 0.5, -0.5).foreach { n => + List(1.0, 0.0, -1.0, 2.0, -2.0, 0.5, -0.5).filter(defined).foreach { n => val constant = evalAt(n) + if (reference != null) { + assertWithinEpsilon(constant, reference(n), s"[c/ref, n=$n]") + } val eval = new Evaluator(Map(x -> n)) val withVar = eval.toDouble(result) assertWithinEpsilon(constant, withVar, s"[c/ev, n=$n]") @@ -31,7 +37,7 @@ class RealTest extends FunSuite { assertWithinEpsilon(withVar, compiled, s"[ev/ir, n=$n]") // derivatives of automated differentiation vs numeric differentiation - if (testDeriv(n)) { + if (derivable(n)) { val dx = 10E-6 val numDiff = (evalAt(n + dx) - evalAt(n - dx)) / (dx * 2) val diffWithVar = eval.toDouble(deriv) @@ -70,6 +76,36 @@ class RealTest extends FunSuite { run("log") { x => x.abs.log } + run("sin", reference = math.sin) { x => + x.sin + } + run("cos", reference = math.cos) { x => + x.cos + } + run("tan", reference = math.tan) { x => + x.tan + } + run("asin", defined = x => x > -1 && x < 1, reference = math.asin) { x => + x.asin + } + run("acos", defined = x => x > -1 && x < 1, reference = math.acos) { x => + x.acos + } + run("atan", reference = math.atan) { x => + x.atan + } + run("sinh", reference = math.sinh) { x => + x.sinh + } + run("cosh", reference = math.cosh) { x => + x.cosh + } + run("tanh", reference = math.tanh) { x => + x.tanh + } + run("cos(x^2)") { x => + (x * x).cos + } run("temp") { x => val t = x * 3 t + t @@ -77,19 +113,19 @@ class RealTest extends FunSuite { run("abs") { x => x.abs } - run("max(x, 0)", testDeriv = _ != 0) { x => + run("max(x, 0)", derivable = _ != 0) { x => x.max(0) } run("max(x, x)") { x => x.max(x) } - run("x > 0 ? x^2 : 1", testDeriv = _ != 0) { x => + run("x > 0 ? x^2 : 1", derivable = _ != 0) { x => Real.gt(x, 0, x * x, 1) } - run("x > 0 ? 1 : x + 1", testDeriv = _ != 0) { x => + run("x > 0 ? 1 : x + 1", derivable = _ != 0) { x => Real.gt(x, 0, 1, x + 1) } - run("x > 0 ? x^2 : x + 1", _ != 0) { x => + run("x > 0 ? x^2 : x + 1", derivable = _ != 0) { x => Real.gt(x, 0, x * x, x + 1) } @@ -108,7 +144,7 @@ class RealTest extends FunSuite { Real.one / (x.exp + 1) } - run("log x^2", testDeriv = _ != 0) { x => + run("log x^2", derivable = _ != 0) { x => x.pow(2).log } @@ -124,20 +160,20 @@ class RealTest extends FunSuite { (x * x * x)) } - run("lookup", testDeriv = _ => false) { x => // not derivable + run("lookup", derivable = _ => false) { x => // not derivable val i = x.abs * 2 //should be a non-negative whole number Lookup(i, Real.seq(List(0, 1, 2, 3, 4))) } val exponents = scala.util.Random.shuffle(-40.to(40)) - run("exponent sums", testDeriv = _ != 0) { x => + run("exponent sums", derivable = _ != 0) { x => exponents.foldLeft(x) { case (a, e) => (a + x.pow(e)) * x } } - run("pow", testDeriv = _ >= 0) { x => + run("pow", derivable = _ >= 0) { x => x.pow(x) }