Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimise boilerplate generators, use instance constructors #3871

Merged
merged 10 commits into from
Jul 15, 2022
8 changes: 8 additions & 0 deletions algebra-core/src/main/scala/algebra/ring/Rig.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,12 @@ trait Rig[@sp(Int, Long, Float, Double) A] extends Any with Semiring[A] with Mul

object Rig extends AdditiveMonoidFunctions[Rig] with MultiplicativeMonoidFunctions[Rig] {
@inline final def apply[A](implicit ev: Rig[A]): Rig[A] = ev

@inline private[algebra] def instance[A](z: A, o: A, add: (A, A) => A, mul: (A, A) => A): Rig[A] =
new Rig[A] {
val zero: A = z
val one: A = o
def plus(x: A, y: A): A = add(x, y)
def times(x: A, y: A): A = mul(x, y)
}
}
9 changes: 9 additions & 0 deletions algebra-core/src/main/scala/algebra/ring/Ring.scala
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,13 @@ trait RingFunctions[R[T] <: Ring[T]] extends AdditiveGroupFunctions[R] with Mult

object Ring extends RingFunctions[Ring] {
@inline final def apply[A](implicit ev: Ring[A]): Ring[A] = ev

@inline private[algebra] def instance[A](z: A, o: A, neg: A => A, add: (A, A) => A, mul: (A, A) => A): Ring[A] =
new Ring[A] {
val zero: A = z
val one: A = o
def negate(x: A): A = neg(x)
def plus(x: A, y: A): A = add(x, y)
def times(x: A, y: A): A = mul(x, y)
}
}
8 changes: 8 additions & 0 deletions algebra-core/src/main/scala/algebra/ring/Rng.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,12 @@ trait Rng[@sp(Int, Long, Float, Double) A] extends Any with Semiring[A] with Add

object Rng extends AdditiveGroupFunctions[Rng] with MultiplicativeSemigroupFunctions[Rng] {
@inline final def apply[A](implicit ev: Rng[A]): Rng[A] = ev

@inline private[algebra] def instance[A](z: A, neg: A => A, add: (A, A) => A, mul: (A, A) => A): Rng[A] =
new Rng[A] {
val zero: A = z
def negate(x: A): A = neg(x)
def plus(x: A, y: A): A = add(x, y)
def times(x: A, y: A): A = mul(x, y)
}
}
7 changes: 7 additions & 0 deletions algebra-core/src/main/scala/algebra/ring/Semiring.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,11 @@ trait Semiring[@sp(Int, Long, Float, Double) A]

object Semiring extends AdditiveMonoidFunctions[Semiring] with MultiplicativeSemigroupFunctions[Semiring] {
@inline final def apply[A](implicit ev: Semiring[A]): Semiring[A] = ev

@inline private[algebra] def instance[A](z: A, add: (A, A) => A, mul: (A, A) => A): Semiring[A] =
new Semiring[A] {
val zero: A = z
def plus(x: A, y: A): A = add(x, y)
def times(x: A, y: A): A = mul(x, y)
}
}
10 changes: 10 additions & 0 deletions kernel/src/main/scala/cats/kernel/CommutativeGroup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,14 @@ object CommutativeGroup extends GroupFunctions[CommutativeGroup] {
* Access an implicit `CommutativeGroup[A]`.
*/
@inline final def apply[A](implicit ev: CommutativeGroup[A]): CommutativeGroup[A] = ev

/**
* Create a `CommutativeGroup` instance from the given inverse and combine functions and empty value.
*/
@inline def instance[A](emp: A, inv: A => A, cmb: (A, A) => A): CommutativeGroup[A] =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if instead we make an instance of InvariantFunctor[CommutativeGroup] and then we use the InvariantFunctor instances in the tuple code Gen?

Copy link
Member Author

@joroKr21 joroKr21 Jul 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You probably mean InvariantSemigroupal so that we can use tupledN right? That's doable but it would mean more allocations and reduced performance at runtime (going from (a, b, c, d) to (a, (b, (c, d))) and back) so I'm not sure that would be acceptable for cats-kernel which is also used by algebra.

new CommutativeGroup[A] {
val empty = emp
def inverse(a: A) = inv(a)
def combine(x: A, y: A) = cmb(x, y)
}
}
10 changes: 10 additions & 0 deletions kernel/src/main/scala/cats/kernel/Group.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,14 @@ object Group extends GroupFunctions[Group] {
* Access an implicit `Group[A]`.
*/
@inline final def apply[A](implicit ev: Group[A]): Group[A] = ev

/**
* Create a `Group` instance from the given inverse and combine functions and empty value.
*/
@inline def instance[A](emp: A, inv: A => A, cmb: (A, A) => A): Group[A] =
new Group[A] {
val empty = emp
def inverse(a: A) = inv(a)
def combine(x: A, y: A) = cmb(x, y)
}
}
12 changes: 9 additions & 3 deletions kernel/src/main/scala/cats/kernel/Hash.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,17 @@ object Hash extends HashFunctions[Hash] {
def eqv(x: A, y: A) = x == y
}

/**
* Create a `Hash` instance from the given hash and equality functions.
*/
@inline def instance[A](h: A => Int, e: (A, A) => Boolean): Hash[A] =
new Hash[A] {
def hash(x: A) = h(x)
def eqv(x: A, y: A) = e(x, y)
}
}

trait HashToHashingConversion {
implicit def catsKernelHashToHashing[A](implicit ev: Hash[A]): Hashing[A] =
new Hashing[A] {
override def hash(x: A): Int = ev.hash(x)
}
ev.hash(_)
}
60 changes: 29 additions & 31 deletions project/AlgebraBoilerplate.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ object AlgebraBoilerplate {
val synVals = (0 until arity).map(n => s"a$n")
val `A..N` = synTypes.mkString(", ")
val `a..n` = synVals.mkString(", ")
val `_.._` = Seq.fill(arity)("_").mkString(", ")
val `(A..N)` = if (arity == 1) "Tuple1[A0]" else synTypes.mkString("(", ", ", ")")
val `(_.._)` = if (arity == 1) "Tuple1[_]" else Seq.fill(arity)("_").mkString("(", ", ", ")")
val `(a..n)` = if (arity == 1) "Tuple1(a)" else synVals.mkString("(", ", ", ")")
}

Expand Down Expand Up @@ -86,32 +84,32 @@ object AlgebraBoilerplate {
import tv._

def constraints(constraint: String) =
synTypes.map(tpe => s"${tpe}: ${constraint}[${tpe}]").mkString(", ")
synTypes.map(tpe => s"$tpe: $constraint[$tpe]").mkString(", ")

def tuple(results: TraversableOnce[String]) = {
val resultsVec = results.toVector
val a = synTypes.size
val r = s"${0.until(a).map(i => resultsVec(i)).mkString(", ")}"
if (a == 1) "Tuple1(" ++ r ++ ")"
else s"(${r})"
else s"($r)"
}

def binMethod(name: String) =
synTypes.zipWithIndex.iterator.map { case (tpe, i) =>
val j = i + 1
s"${tpe}.${name}(x._${j}, y._${j})"
s"$tpe.$name(x._$j, y._$j)"
}

def binTuple(name: String) =
tuple(binMethod(name))

def unaryTuple(name: String) = {
val m = synTypes.zipWithIndex.map { case (tpe, i) => s"${tpe}.${name}(x._${i + 1})" }
val m = synTypes.zipWithIndex.map { case (tpe, i) => s"$tpe.$name(x._${i + 1})" }
tuple(m)
}

def nullaryTuple(name: String) = {
val m = synTypes.map(tpe => s"${tpe}.${name}")
val m = synTypes.map(tpe => s"$tpe.$name")
tuple(m)
}

Expand All @@ -124,36 +122,36 @@ object AlgebraBoilerplate {
|trait TupleInstances extends cats.kernel.instances.TupleInstances {
-
- implicit def tuple${arity}Rig[${`A..N`}](implicit ${constraints("Rig")}): Rig[${`(A..N)`}] =
- new Rig[${`(A..N)`}] {
- def one: ${`(A..N)`} = ${nullaryTuple("one")}
- def plus(x: ${`(A..N)`}, y: ${`(A..N)`}): ${`(A..N)`} = ${binTuple("plus")}
- def times(x: ${`(A..N)`}, y: ${`(A..N)`}): ${`(A..N)`} = ${binTuple("times")}
- def zero: ${`(A..N)`} = ${nullaryTuple("zero")}
- }
- Rig.instance(
- ${nullaryTuple("zero")},
- ${nullaryTuple("one")},
- (x, y) => ${binTuple("plus")},
- (x, y) => ${binTuple("times")}
- )
-
- implicit def tuple${arity}Ring[${`A..N`}](implicit ${constraints("Ring")}): Ring[${`(A..N)`}] =
- new Ring[${`(A..N)`}] {
- def one: ${`(A..N)`} = ${nullaryTuple("one")}
- def plus(x: ${`(A..N)`}, y: ${`(A..N)`}): ${`(A..N)`} = ${binTuple("plus")}
- def times(x: ${`(A..N)`}, y: ${`(A..N)`}): ${`(A..N)`} = ${binTuple("times")}
- def zero: ${`(A..N)`} = ${nullaryTuple("zero")}
- def negate(x: ${`(A..N)`}): ${`(A..N)`} = ${unaryTuple("negate")}
- }
- Ring.instance(
- ${nullaryTuple("zero")},
- ${nullaryTuple("one")},
- x => ${unaryTuple("negate")},
- (x, y) => ${binTuple("plus")},
- (x, y) => ${binTuple("times")}
- )
-
- implicit def tuple${arity}Rng[${`A..N`}](implicit ${constraints("Rng")}): Rng[${`(A..N)`}] =
- new Rng[${`(A..N)`}] {
- def plus(x: ${`(A..N)`}, y: ${`(A..N)`}): ${`(A..N)`} = ${binTuple("plus")}
- def times(x: ${`(A..N)`}, y: ${`(A..N)`}): ${`(A..N)`} = ${binTuple("times")}
- def zero: ${`(A..N)`} = ${nullaryTuple("zero")}
- def negate(x: ${`(A..N)`}): ${`(A..N)`} = ${unaryTuple("negate")}
- }
- Rng.instance(
- ${nullaryTuple("zero")},
- x => ${unaryTuple("negate")},
- (x, y) => ${binTuple("plus")},
- (x, y) => ${binTuple("times")}
- )
-
- implicit def tuple${arity}Semiring[${`A..N`}](implicit ${constraints("Semiring")}): Semiring[${`(A..N)`}] =
- new Semiring[${`(A..N)`}] {
- def plus(x: ${`(A..N)`}, y: ${`(A..N)`}): ${`(A..N)`} = ${binTuple("plus")}
- def times(x: ${`(A..N)`}, y: ${`(A..N)`}): ${`(A..N)`} = ${binTuple("times")}
- def zero: ${`(A..N)`} = ${nullaryTuple("zero")}
- }
- Semiring.instance(
- ${nullaryTuple("zero")},
- (x, y) => ${binTuple("plus")},
- (x, y) => ${binTuple("times")}
- )
|}
"""
}
Expand Down
16 changes: 4 additions & 12 deletions project/Boilerplate.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,6 @@ object Boilerplate {
if (arity <= 2) "(*, *)"
else `A..(N - 2)`.mkString("(", ", ", ", *, *)")
val `a..(n - 1)` = (0 until (arity - 1)).map(n => s"a$n")
val `fa._1..fa._(n - 2)` =
if (arity <= 2) "" else (0 until (arity - 2)).map(n => s"fa._${n + 1}").mkString("", ", ", ", ")
val `pure(fa._1..(n - 2))` =
if (arity <= 2) "" else (0 until (arity - 2)).map(n => s"G.pure(fa._${n + 1})").mkString("", ", ", ", ")
val `a0, a(n - 1)` = if (arity <= 1) "" else `a..(n - 1)`.mkString(", ")
val `[A0, A(N - 1)]` = if (arity <= 1) "" else `A..(N - 1)`.mkString("[", ", ", "]")
val `(A0, A(N - 1))` =
Expand All @@ -87,19 +83,15 @@ object Boilerplate {
val `(A..N - 1, *)` =
if (arity == 1) "Tuple1"
else `A..(N - 1)`.mkString("(", ", ", ", *)")
val `(fa._1..(n - 1))` =
if (arity <= 1) "Tuple1.apply" else (0 until (arity - 1)).map(n => s"fa._${n + 1}").mkString("(", ", ", ", _)")

def `A0, A(N - 1)&`(a: String): String =
if (arity <= 1) s"Tuple1[$a]" else `A..(N - 1)`.mkString("(", ", ", s", $a)")

def `fa._1..(n - 1) & `(a: String): String =
if (arity <= 1) s"Tuple1($a)" else (0 until (arity - 1)).map(n => s"fa._${n + 1}").mkString("(", ", ", s", $a)")

def `constraints A..N`(c: String): String = synTypes.map(tpe => s"$tpe: $c[$tpe]").mkString("(implicit ", ", ", ")")
def `constraints A..N`(c: String): String =
synTypes.map(tpe => s"$tpe: $c[$tpe]").mkString("(implicit ", ", ", ")")
def `constraints A..(N-1)`(c: String): String =
if (arity <= 1) "" else `A..(N - 1)`.map(tpe => s"$tpe: $c[$tpe]").mkString("(implicit ", ", ", ")")
def `parameters A..(N-1)`(c: String): String = `A..(N - 1)`.map(tpe => s"$tpe: $c[$tpe]").mkString(", ")
def `parameters A..(N-1)`(c: String): String =
`A..(N - 1)`.map(tpe => s"$tpe: $c[$tpe]").mkString(", ")
}

trait Template {
Expand Down
Loading