Skip to content

Commit

Permalink
implement basic BFT RDT
Browse files Browse the repository at this point in the history
  • Loading branch information
Kiibou-chan committed Jul 30, 2024
1 parent 7d1af04 commit 560a257
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 0 deletions.
115 changes: 115 additions & 0 deletions Modules/RDTs/src/main/scala/rdts/datatypes/BFT.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package rdts.datatypes

import rdts.base.Lattice.{mapLattice, *}
import rdts.base.{Bottom, Lattice, Uid}
import rdts.dotted.HasDots

import java.security.MessageDigest
import java.util
import scala.collection.mutable

trait Byteable[T] {

def toBytes(obj: T): Array[Byte]

}

object Byteable {

def toStringBased[T]: Byteable[T] = (obj: T) => obj.toString.getBytes

}

case class Hash(content: Array[Byte]) {
override def toString: String = s"[#${content.mkString("")}]"

override def canEqual(that: Any): Boolean = that.getClass == getClass

override def equals(obj: Any): Boolean = canEqual(obj) && util.Arrays.equals(obj.asInstanceOf[Hash].content, content)
}

case class BFTDelta[V](value: V, predecessors: Set[Hash], hash: Hash) {

def hashCorrect(using Byteable[V]): Boolean = {
BFT.hash(value, predecessors) == hash
}

}

object BFTDelta {

def apply[V](value: V, predecessors: Set[Hash])(using Byteable[V]): BFTDelta[V] =
BFTDelta(value, predecessors, BFT.hash(value, predecessors))

}

case class BFT[V] private[rdts](deltas: Set[BFTDelta[V]]) {

def value(using b: Bottom[V], lat: Lattice[V]): V = {
val graph = reverseGraph()

if (!graph.contains(None)) return b.empty

val worklist = mutable.Queue[BFTDelta[V]](graph(None).toList *)

val connected = mutable.Set[BFTDelta[V]]()

while (worklist.nonEmpty) {
val elem = worklist.dequeue()
connected.add(elem)
worklist.enqueueAll(graph.getOrElse(Some(elem.hash), Set.empty[BFTDelta[V]]))
}

connected.map(_.value).foldLeft(b.empty)((l, r) => l.merge(r))
}

/**
* Generates a BFT delta containing an update with a value delta. This assumes, that the RDT that's wrapped in the
* BFT generates deltas.
*
* @param f Function which takes the current state as input and returns a new delta.
* @return BFT containing one BFTDelta with the update.
*/
def update(f: V => V)(using Byteable[V], Lattice[V], Bottom[V]): BFT[V] = {
val newValue = f(value)

val delta = BFTDelta(newValue, heads, BFT.hash(newValue, heads))

BFT(Set(delta))
}

lazy val heads: Set[Hash] = deltas.filter { item => deltas.forall { a => !a.predecessors.contains(item.hash) } }.map(_.hash)

def reverseGraph(): Map[Option[Hash], Set[BFTDelta[V]]] = {
val reverseGraph = mutable.Map[Option[Hash], Set[BFTDelta[V]]]()

def addToGraph(from: Option[Hash], to: BFTDelta[V]) = {
reverseGraph.updateWith(from)(_.fold(Some(Set(to)))(it => Some(it + to)))
}

for (delta <- deltas) {
if (delta.predecessors.isEmpty) addToGraph(None, delta)
else for (hash <- delta.predecessors) addToGraph(Some(hash), delta)
}

reverseGraph.toMap
}

}

object BFT {

val digest: MessageDigest = MessageDigest.getInstance("SHA-256")

def apply[V](initial: V)(using Byteable[V]): BFT[V] = BFT(Set(BFTDelta(initial, Set.empty)))

def lattice[V](using lat: Lattice[V])(using Byteable[V]): Lattice[BFT[V]] = {
(left: BFT[V], right: BFT[V]) => {
BFT((left.deltas ++ right.deltas).filter(_.hashCorrect))
}
}

def hash[V](value: V, heads: Set[Hash])(using ch: Byteable[V]): Hash =
Hash(BFT.digest.digest(Array.concat(ch.toBytes(value) :: heads.toList.map(_.content) *)))

}
32 changes: 32 additions & 0 deletions Modules/RDTs/src/test/scala/test/rdts/bft/BFTTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package test.rdts.bft

import rdts.base.*
import rdts.datatypes.*

class BFTTest extends munit.ScalaCheckSuite {


given Byteable[GrowOnlyCounter] = (obj: GrowOnlyCounter) => obj.inner.toString.getBytes
given Lattice[BFT[GrowOnlyCounter]] = BFT.lattice

test("basic update") {
val id1 = LocalUid.gen()

val bottom = BFT(summon[Bottom[GrowOnlyCounter]].empty)

val u1 = bottom.update(_.inc()(using id1))

val res = bottom.merge(u1)

assertEquals(bottom.value.value, 0)
assertEquals(u1.value.value, 0)
assertEquals(res.value.value, 1)

assertEquals(u1.deltas.toList(0).predecessors.toList(0), bottom.deltas.toList(0).hash)

assertEquals(bottom.deltas.size, 1)
assertEquals(u1.deltas.size, 1)
assertEquals(res.deltas.size, 2)
}

}

0 comments on commit 560a257

Please sign in to comment.