Skip to content

Commit

Permalink
Implementation of trigonometric functions (#331)
Browse files Browse the repository at this point in the history
* Implementation of cos and sin

* Implementation of tan, asin, acos, atan

* Simplifications in derivate expressions

* Additional tests for sinh, cosh and tanh
  • Loading branch information
arbimo authored and avi-stripe committed Feb 26, 2019
1 parent 3026eae commit 673408e
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,16 @@ private object Gradient {
Real.zero,
Real.zero,
gradient.toReal * child.original / child)
case NoOp => gradient.toReal
case NoOp => gradient.toReal
case SinOp => gradient.toReal * child.original.cos
case CosOp => gradient.toReal * (Real.zero - child.original.sin)
case TanOp => gradient.toReal / child.original.cos.pow(2)
case AsinOp =>
gradient.toReal / (Real.one - child.original.pow(2)).pow(0.5)
case AcosOp =>
-gradient.toReal / (Real.one - child.original.pow(2)).pow(0.5)
case AtanOp =>
gradient.toReal / (Real.one + child.original.pow(2))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ sealed trait Real {
def +(other: Real): Real = RealOps.add(this, other)
def *(other: Real): Real = RealOps.multiply(this, other)

def -(other: Real): Real = this + (other * -1)
def unary_- : Real = this * (-1)
def -(other: Real): Real = this + (-other)
def /(other: Real): Real = RealOps.divide(this, other)

def min(other: Real): Real = RealOps.min(this, other)
Expand All @@ -30,6 +31,18 @@ sealed trait Real {
def exp: Real = RealOps.unary(this, ir.ExpOp)
def log: Real = RealOps.unary(this, ir.LogOp)

def sin: Real = RealOps.unary(this, ir.SinOp)
def cos: Real = RealOps.unary(this, ir.CosOp)
def tan: Real = RealOps.unary(this, ir.TanOp)

def asin: Real = RealOps.unary(this, ir.AsinOp)
def acos: Real = RealOps.unary(this, ir.AcosOp)
def atan: Real = RealOps.unary(this, ir.AtanOp)

def sinh: Real = (this.exp - (-this).exp) / 2
def cosh: Real = (this.exp + (-this).exp) / 2
def tanh: Real = this.sinh / this.cosh

//because abs does not have a smooth derivative, try to avoid using it
def abs: Real = RealOps.unary(this, ir.AbsOp)

Expand Down Expand Up @@ -89,9 +102,11 @@ object Real {
private[compute] val BigZero = BigDecimal(0.0)
private[compute] val BigOne = BigDecimal(1.0)
private[compute] val BigTwo = BigDecimal(2.0)
private[compute] val BigPi = BigDecimal(math.Pi)
val zero: Real = Constant(BigZero)
val one: Real = Constant(BigOne)
val two: Real = Constant(BigTwo)
val Pi: Real = Constant(BigPi)
val infinity: Real = Infinity
val negInfinity: Real = NegInfinity
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,60 @@ private[compute] object RealOps {

def unary(original: Real, op: UnaryOp): Real =
original match {
case Infinity => Infinity
case Infinity =>
op match {
case ExpOp => Infinity
case LogOp => Infinity
case AbsOp => Infinity
case SinOp =>
throw new ArithmeticException(
"No limit for 'sin' at positive infinity")
case CosOp =>
throw new ArithmeticException(
"No limit for 'cos' at positive infinity")
case TanOp =>
throw new ArithmeticException(
"No limit for 'tan' at positive infinity")
case AcosOp => throw new ArithmeticException("acos undefined above 1")
case AsinOp => throw new ArithmeticException("asin undefined above 1")
case AtanOp => Real.Pi / 2
case NoOp => Infinity
}
case NegInfinity =>
op match {
case ExpOp => Real.zero
case LogOp =>
throw new ArithmeticException(
"Cannot take the log of a negative number")
case AbsOp => Infinity
case NoOp => original
case SinOp =>
throw new ArithmeticException(
"No limit for 'sin' at negative infinity")
case CosOp =>
throw new ArithmeticException(
"No limit for 'cos' at negative infinity")
case TanOp =>
throw new ArithmeticException(
"No limit for 'tan' at negative infinity")
case AcosOp =>
throw new ArithmeticException("acos undefined below -1")
case AsinOp =>
throw new ArithmeticException("asin undefined below -1")
case AtanOp => -Real.Pi / 2
case NoOp => original
}
case Constant(Real.BigZero) =>
op match {
case ExpOp => Real.one
case LogOp => NegInfinity
case AbsOp => Real.zero
case NoOp => original
case ExpOp => Real.one
case LogOp => NegInfinity
case AbsOp => Real.zero
case SinOp => Real.zero
case CosOp => Real.one
case TanOp => Real.zero
case AsinOp => Real.zero
case AcosOp => Real.Pi / 2
case AtanOp => Real.zero
case NoOp => original
}
case Constant(value) =>
op match {
Expand All @@ -32,8 +70,14 @@ private[compute] object RealOps {
"Cannot take the log of " + value.toDouble)
else
Real(Math.log(value.toDouble))
case AbsOp => Real(value.abs)
case NoOp => original
case AbsOp => Real(value.abs)
case SinOp => Real(Math.sin(value.toDouble))
case CosOp => Real(Math.cos(value.toDouble))
case TanOp => Real(Math.tan(value.toDouble))
case AsinOp => Real(Math.asin(value.toDouble))
case AcosOp => Real(Math.acos(value.toDouble))
case AtanOp => Real(Math.atan(value.toDouble))
case NoOp => original
}
case nc: NonConstant =>
val opt = (op, nc) match {
Expand Down
14 changes: 10 additions & 4 deletions rainier-core/src/main/scala/com/stripe/rainier/ir/IRViz.scala
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,16 @@ object IRViz {

def opLabel(op: UnaryOp): String =
op match {
case ExpOp => "exp"
case LogOp => "ln"
case AbsOp => "abs"
case NoOp => "nop"
case ExpOp => "exp"
case LogOp => "ln"
case AbsOp => "abs"
case SinOp => "sin"
case CosOp => "cos"
case TanOp => "tan"
case AsinOp => "asin"
case AcosOp => "acos"
case AtanOp => "atan"
case NoOp => "nop"
}

def opLabel(op: BinaryOp): String =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,16 @@ private trait MethodGenerator {

def unaryOp(op: UnaryOp): Unit = {
(op match {
case LogOp => Some(("java/lang/Math", "log"))
case ExpOp => Some(("java/lang/Math", "exp"))
case AbsOp => Some(("java/lang/Math", "abs"))
case NoOp => None
case LogOp => Some(("java/lang/Math", "log"))
case ExpOp => Some(("java/lang/Math", "exp"))
case AbsOp => Some(("java/lang/Math", "abs"))
case CosOp => Some(("java/lang/Math", "cos"))
case SinOp => Some(("java/lang/Math", "sin"))
case TanOp => Some(("java/lang/Math", "tan"))
case AsinOp => Some(("java/lang/Math", "asin"))
case AcosOp => Some(("java/lang/Math", "acos"))
case AtanOp => Some(("java/lang/Math", "atan"))
case NoOp => None
}).foreach {
case (className, methodName) =>
methodNode.visitMethodInsn(INVOKESTATIC,
Expand Down
8 changes: 8 additions & 0 deletions rainier-core/src/main/scala/com/stripe/rainier/ir/Ops.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,11 @@ case object ExpOp extends UnaryOp
case object LogOp extends UnaryOp
case object AbsOp extends UnaryOp
case object NoOp extends UnaryOp

case object SinOp extends UnaryOp
case object CosOp extends UnaryOp
case object TanOp extends UnaryOp

case object AsinOp extends UnaryOp
case object AcosOp extends UnaryOp
case object AtanOp extends UnaryOp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import org.scalatest._
import com.stripe.rainier.core._
import scala.util.{Try, Success, Failure}
class RealTest extends FunSuite {
def run(description: String, testDeriv: Double => Boolean = _ => true)(
fn: Real => Real): Unit = {
def run(description: String,
defined: Double => Boolean = _ => true,
derivable: Double => Boolean = _ => true,
reference: Double => Double = null)(fn: Real => Real): Unit = {
test(description) {
val x = new Variable
val result = fn(x)
Expand All @@ -17,21 +19,25 @@ class RealTest extends FunSuite {
case Success(Constant(bd)) => bd.toDouble
case Failure(_: ArithmeticException) => 0.0 / 0.0
case Failure(_: NumberFormatException) => 0.0 / 0.0
case Failure(e) => throw e
case x => sys.error("Non-constant value " + x)
}

val c = Compiler(200, 100).compile(List(x), result)
val dc = Compiler(200, 100).compile(List(x), deriv)
List(1.0, 0.0, -1.0, 2.0, -2.0, 0.5, -0.5).foreach { n =>
List(1.0, 0.0, -1.0, 2.0, -2.0, 0.5, -0.5).filter(defined).foreach { n =>
val constant = evalAt(n)
if (reference != null) {
assertWithinEpsilon(constant, reference(n), s"[c/ref, n=$n]")
}
val eval = new Evaluator(Map(x -> n))
val withVar = eval.toDouble(result)
assertWithinEpsilon(constant, withVar, s"[c/ev, n=$n]")
val compiled = c(Array(n))
assertWithinEpsilon(withVar, compiled, s"[ev/ir, n=$n]")

// derivatives of automated differentiation vs numeric differentiation
if (testDeriv(n)) {
if (derivable(n)) {
val dx = 10E-6
val numDiff = (evalAt(n + dx) - evalAt(n - dx)) / (dx * 2)
val diffWithVar = eval.toDouble(deriv)
Expand Down Expand Up @@ -70,26 +76,56 @@ class RealTest extends FunSuite {
run("log") { x =>
x.abs.log
}
run("sin", reference = math.sin) { x =>
x.sin
}
run("cos", reference = math.cos) { x =>
x.cos
}
run("tan", reference = math.tan) { x =>
x.tan
}
run("asin", defined = x => x > -1 && x < 1, reference = math.asin) { x =>
x.asin
}
run("acos", defined = x => x > -1 && x < 1, reference = math.acos) { x =>
x.acos
}
run("atan", reference = math.atan) { x =>
x.atan
}
run("sinh", reference = math.sinh) { x =>
x.sinh
}
run("cosh", reference = math.cosh) { x =>
x.cosh
}
run("tanh", reference = math.tanh) { x =>
x.tanh
}
run("cos(x^2)") { x =>
(x * x).cos
}
run("temp") { x =>
val t = x * 3
t + t
}
run("abs") { x =>
x.abs
}
run("max(x, 0)", testDeriv = _ != 0) { x =>
run("max(x, 0)", derivable = _ != 0) { x =>
x.max(0)
}
run("max(x, x)") { x =>
x.max(x)
}
run("x > 0 ? x^2 : 1", testDeriv = _ != 0) { x =>
run("x > 0 ? x^2 : 1", derivable = _ != 0) { x =>
Real.gt(x, 0, x * x, 1)
}
run("x > 0 ? 1 : x + 1", testDeriv = _ != 0) { x =>
run("x > 0 ? 1 : x + 1", derivable = _ != 0) { x =>
Real.gt(x, 0, 1, x + 1)
}
run("x > 0 ? x^2 : x + 1", _ != 0) { x =>
run("x > 0 ? x^2 : x + 1", derivable = _ != 0) { x =>
Real.gt(x, 0, x * x, x + 1)
}

Expand All @@ -108,7 +144,7 @@ class RealTest extends FunSuite {
Real.one / (x.exp + 1)
}

run("log x^2", testDeriv = _ != 0) { x =>
run("log x^2", derivable = _ != 0) { x =>
x.pow(2).log
}

Expand All @@ -124,20 +160,20 @@ class RealTest extends FunSuite {
(x * x * x))
}

run("lookup", testDeriv = _ => false) { x => // not derivable
run("lookup", derivable = _ => false) { x => // not derivable
val i = x.abs * 2 //should be a non-negative whole number
Lookup(i, Real.seq(List(0, 1, 2, 3, 4)))
}

val exponents = scala.util.Random.shuffle(-40.to(40))
run("exponent sums", testDeriv = _ != 0) { x =>
run("exponent sums", derivable = _ != 0) { x =>
exponents.foldLeft(x) {
case (a, e) =>
(a + x.pow(e)) * x
}
}

run("pow", testDeriv = _ >= 0) { x =>
run("pow", derivable = _ >= 0) { x =>
x.pow(x)
}

Expand Down

0 comments on commit 673408e

Please sign in to comment.