Skip to content

Commit

Permalink
Add derivation of DeltaSurgeon for sum types
Browse files Browse the repository at this point in the history
  • Loading branch information
ckuessner committed Jul 23, 2024
1 parent d1ed843 commit fafeb2d
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import rdts.base.Bottom
import rdts.dotted.{Dotted, Obrem}
import rdts.time.Dots

import scala.compiletime.{constValue, erasedValue, summonAll}
import scala.compiletime.*
import scala.deriving.Mirror

case class IsolatedDeltaParts(inner: Map[String, IsolatedDeltaParts] | Array[Byte]) {
Expand Down Expand Up @@ -46,22 +46,29 @@ object DeltaSurgeon {
inline def apply[T](using deltaSurgeon: DeltaSurgeon[T]): DeltaSurgeon[T] = deltaSurgeon

// From https://blog.philipp-martini.de/blog/magic-mirror-scala3/
private inline def getFactorLabels[A <: Tuple]: List[String] = inline erasedValue[A] match {
private inline def getLabels[A <: Tuple]: List[String] = inline erasedValue[A] match {
case _: EmptyTuple => Nil
case _: (t *: ts) => constValue[t].toString :: getFactorLabels[ts]
case _: (t *: ts) => constValue[t].toString :: getLabels[ts]
}

inline def derived[T <: Product](using pm: Mirror.ProductOf[T], productBottom: Bottom[T]): DeltaSurgeon[T] =
val factorBottoms = summonAll[Tuple.Map[pm.MirroredElemTypes, Bottom]].toIArray.map(_.asInstanceOf[Bottom[Any]])
val factorSurgeons =
summonAll[Tuple.Map[pm.MirroredElemTypes, DeltaSurgeon]].toIArray.map(_.asInstanceOf[DeltaSurgeon[Any]])
val factorLabels = getFactorLabels[pm.MirroredElemLabels]
ProductTypeSurgeon[T](pm, productBottom, factorLabels, factorBottoms, factorSurgeons)
inline def derived[T](using m: Mirror.Of[T], bottom: Bottom[T]): DeltaSurgeon[T] =
val elementLabels = getLabels[m.MirroredElemLabels].toArray
val elementBottoms = summonAll[Tuple.Map[m.MirroredElemTypes, Bottom]].toIArray.map(_.asInstanceOf[Bottom[Any]])
val elementSurgeons =
summonAll[Tuple.Map[m.MirroredElemTypes, DeltaSurgeon]].toIArray.map(_.asInstanceOf[DeltaSurgeon[Any]])
inline m match
case sumMirror: Mirror.SumOf[T] =>
SumTypeDeltaSurgeon[T](sumMirror, bottom, elementLabels, elementBottoms, elementSurgeons)
case productMirror: Mirror.ProductOf[T] =>
ProductTypeSurgeon[T](productMirror, bottom, elementLabels, elementBottoms, elementSurgeons)

private inline given sumElemLabels[T](using sm: Mirror.SumOf[T]): List[Any] =
constValueTuple[sm.MirroredElemLabels].map[[X] =>> String]([X] => (x: X) => x.toString).toList

class ProductTypeSurgeon[T](
pm: Mirror.ProductOf[T],
productBottom: Bottom[T], // The bottom of the product (derivable as the product of bottoms)
factorLabels: List[String], // Maps the factor label to the factor index
factorLabels: Array[String], // Maps the factor label to the factor index
factorBottoms: IArray[Bottom[Any]], // The Bottom TypeClass instance for each factor
factorSurgeons: IArray[DeltaSurgeon[Any]], // The DeltaSurgeon TypeClass instance for each factor
) extends DeltaSurgeon[T]:
Expand Down Expand Up @@ -101,15 +108,46 @@ object DeltaSurgeon {
case byteArray => ???
}

import lofi_acl.sync.JsoniterCodecs.dotsCodec
class SumTypeDeltaSurgeon[T](
sm: Mirror.SumOf[T],
sumBottom: Bottom[T], // The bottom of the product (derivable as the product of bottoms)
elementLabels: Array[String], // Maps the factor label to the factor index
elementBottoms: IArray[Bottom[Any]], // The Bottom TypeClass instance for each factor
elementSurgeons: IArray[DeltaSurgeon[Any]], // The DeltaSurgeon TypeClass instance for each factor
) extends DeltaSurgeon[T]:
private val ordinalLookup = elementLabels.zipWithIndex.toMap

given dotsDeltaSurgeon: DeltaSurgeon[Dots] = ofTerminalValue[Dots]
given dottedDeltaSurgeon[T: DeltaSurgeon: Bottom]: DeltaSurgeon[Dotted[T]] = DeltaSurgeon.derived
given obremDeltaSurgeon[T: DeltaSurgeon: Bottom]: DeltaSurgeon[Obrem[T]] = DeltaSurgeon.derived
override def isolate(delta: T): IsolatedDeltaParts = {
val ordinal = sm.ordinal(delta)
val label = elementLabels(sm.ordinal(delta))
val isolatedElement = elementSurgeons(ordinal).isolate(delta)
IsolatedDeltaParts(Map(label -> isolatedElement))
}

override def recombine(parts: IsolatedDeltaParts): T = {
parts.inner match
case map: Map[String, IsolatedDeltaParts] =>
require(map.size == 1)
map.head match
case (elementType, element) =>
val ordinal = ordinalLookup(elementType)
elementSurgeons(ordinal).recombine(element).asInstanceOf[T]
case arr: Array[Byte] => ???
}

// Used for values that should not be further isolated
def ofTerminalValue[V: Bottom: JsonValueCodec]: DeltaSurgeon[V] = new TerminalValueDeltaSurgeon[V]

// TODO: could be used inside of sum type derivation to avoid the need to specify instances externally
// TODO: Restrict type to case objects
def ofCaseObject[V](obj: V): DeltaSurgeon[V] = {
new DeltaSurgeon[V]:
override def isolate(delta: V): IsolatedDeltaParts = IsolatedDeltaParts(Array.empty)
override def recombine(parts: IsolatedDeltaParts): V = parts.inner match
case array: Array[Byte] if array.isEmpty => obj
case _ => ???
}

private class TerminalValueDeltaSurgeon[V: Bottom: JsonValueCodec] extends DeltaSurgeon[V] {
override def isolate(delta: V): IsolatedDeltaParts =
if Bottom[V].isEmpty(delta) then IsolatedDeltaParts(Map.empty)
Expand All @@ -122,4 +160,15 @@ object DeltaSurgeon {
case serializedValue: Array[Byte] => readFromArray(serializedValue)
}

import lofi_acl.sync.JsoniterCodecs.dotsCodec

given dotsDeltaSurgeon: DeltaSurgeon[Dots] = ofTerminalValue[Dots]
given dottedDeltaSurgeon[T: DeltaSurgeon: Bottom]: DeltaSurgeon[Dotted[T]] = DeltaSurgeon.derived
given obremDeltaSurgeon[T: DeltaSurgeon: Bottom]: DeltaSurgeon[Obrem[T]] = DeltaSurgeon.derived
given noneBottom: Bottom[None.type] = Bottom.provide(None)
given noneDeltaSurgeon: DeltaSurgeon[None.type] = DeltaSurgeon.ofCaseObject(None)
given optionBottom[T]: Bottom[Option[T]] = Bottom.provide(None)
given someBottom[T: Bottom]: Bottom[Some[T]] = Bottom.derived
given someSurgeon[T: Bottom: DeltaSurgeon]: DeltaSurgeon[Some[T]] = DeltaSurgeon.derived

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@ package lofi_acl.access

import com.github.plokhotnyuk.jsoniter_scala.core.{JsonValueCodec, readFromArray, writeToArray}
import com.github.plokhotnyuk.jsoniter_scala.macros.JsonCodecMaker
import lofi_acl.access.DeltaSurgeon.getLabels
import lofi_acl.access.DeltaSurgeonTest.given
import lofi_acl.access.Permission.{ALLOW, PARTIAL}
import lofi_acl.access.PermissionTree.{allow, empty}
import munit.FunSuite
import org.junit.Assert
import rdts.base
import rdts.base.Bottom

import scala.compiletime.{constValue, constValueTuple, erasedValue, summonAll}
import scala.deriving.Mirror

case class A(a: String, b: B)
case class B(c: String)

Expand Down Expand Up @@ -83,4 +88,39 @@ class DeltaSurgeonTest extends FunSuite {
A("", B(""))
)
}

private inline given sumElemLabels[T](using sm: Mirror.SumOf[T]): Tuple.Map[sm.MirroredElemLabels, [X] =>> String] =
constValueTuple[sm.MirroredElemLabels].map[[X] =>> String]([X] => (x: X) => x.toString)

inline def testDerive[T](using m: Mirror.Of[T], bottom: Bottom[T]): Unit =
val elementSurgeons =
summonAll[Tuple.Map[m.MirroredElemTypes, DeltaSurgeon]].toIArray.map(_.asInstanceOf[DeltaSurgeon[Any]])
val elementBottoms = summonAll[Tuple.Map[m.MirroredElemTypes, Bottom]].toIArray.map(_.asInstanceOf[Bottom[Any]])
val elementLabels = getLabels[m.MirroredElemLabels].toArray
println(elementSurgeons)
println(elementBottoms)
println(elementLabels)

private inline def getLabels[A <: Tuple]: List[String] = inline erasedValue[A] match {
case _: EmptyTuple => Nil
case _: (t *: ts) => constValue[t].toString :: getLabels[ts]
}

import DeltaSurgeon.given
private given stringBottom: Bottom[String] = Bottom.provide("")
private val optionSurgeon = DeltaSurgeon.derived[Option[String]]

test("derivation of sums") {
val some: Option[String] = Some("Test")
val none: Option[String] = None

assertEquals(
optionSurgeon.recombine(optionSurgeon.isolate(some)),
some
)
assertEquals(
optionSurgeon.recombine(optionSurgeon.isolate(none)),
none
)
}
}

0 comments on commit fafeb2d

Please sign in to comment.