Skip to content

Commit

Permalink
Map union and intersection
Browse files Browse the repository at this point in the history
  • Loading branch information
ericeil committed Dec 16, 2024
1 parent be55761 commit 9a78d9a
Show file tree
Hide file tree
Showing 7 changed files with 329 additions and 1 deletion.
152 changes: 152 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
Applies a merge function to all entries in this Treap node.
*/
abstract fun getShallowMerger(merger: (K, V?, V?) -> V?): (S?, S?) -> S?
abstract fun getShallowUnionMerger(merger: (K, V, V) -> V): (S, S) -> S
abstract fun getShallowIntersectMerger(merger: (K, V, V) -> V): (S, S) -> S?

private fun containsEntry(entry: Map.Entry<K, V>): Boolean {
val key = entry.key
Expand Down Expand Up @@ -152,6 +154,52 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
override operator fun iterator() = entrySequence().map { it.value }.iterator()
}

override fun union(m: Map<K, V>, merger: (K, V, V) -> V): TreapMap<K, V> =
m.useAsTreap(
{ otherTreap -> self.unionWith(otherTreap, getShallowUnionMerger(merger)) ?: clear() },
{ fallbackUnion(m, merger) }
)

override fun parallelUnion(m: Map<K, V>, parallelThresholdLog2: Int, merger: (K, V, V) -> V): TreapMap<K, V> =
m.useAsTreap(
{ otherTreap -> self.parallelUnionWith(otherTreap, parallelThresholdLog2, getShallowUnionMerger(merger)) ?: clear() },
{ fallbackUnion(m, merger) }
)

private fun fallbackUnion(m: Map<K, V>, merger: (K, V, V) -> V): TreapMap<K, V> {
var newThis = this as TreapMap<K, V>
for ((k, v) in m.entries) {
if (k in this) {
newThis = newThis + (k to merger(k, this[k]!!, v))
} else {
newThis = newThis + (k to v)
}
}
return newThis
}

override fun intersect(m: Map<K, V>, merger: (K, V, V) -> V): TreapMap<K, V> =
m.useAsTreap(
{ otherTreap -> self.intersectWith(otherTreap, getShallowIntersectMerger(merger)) ?: clear() },
{ fallbackIntersect(m, merger) }
)

override fun parallelIntersect(m: Map<K, V>, parallelThresholdLog2: Int, merger: (K, V, V) -> V): TreapMap<K, V> =
m.useAsTreap(
{ otherTreap -> self.parallelIntersectWith(otherTreap, parallelThresholdLog2, getShallowIntersectMerger(merger)) ?: clear() },
{ fallbackIntersect(m, merger) }
)

private fun fallbackIntersect(m: Map<K, V>, merger: (K, V, V) -> V): TreapMap<K, V> {
var newThis = clear()
for ((k, v) in m.entries) {
if (k in this) {
newThis = newThis + (k to merger(k, this[k]!!, v))
}
}
return newThis
}

/**
Merges the entries in `m` with the entries in this AbstractTreapMap, applying the "merger" function to get the
new values for each key.
Expand Down Expand Up @@ -491,3 +539,107 @@ private fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.merge
return newThis?.with(newLeft, newRight) ?: (newLeft join newRight)
}

internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.unionWith(
that: S?,
shallowUnion: (S, S) -> S
): S? =
notForking(this to that) {
unionWithImpl(that, shallowUnion)
}

internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.parallelUnionWith(
that: S?,
parallelThresholdLog2: Int,
shallowUnion: (S, S) -> S
): S? =
maybeForking(
this to that,
{
it.first.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1) &&
it.second.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1)
}
) {
unionWithImpl(that, shallowUnion)
}

context(ThresholdForker<Pair<S?, S?>>)
private fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.unionWithImpl(
that: S?,
shallowUnion: (S, S) -> S
): S? {
val (newLeft, newRight, newThis) = when {
this == null -> return that
that == null -> return this
this.comparePriorityTo(that) >= 0 -> {
val thatSplit = that.split(this)
fork(
this to that,
{ this.left.unionWithImpl(thatSplit.left, shallowUnion) },
{ this.right.unionWithImpl(thatSplit.right, shallowUnion) },
{ thatSplit.duplicate?.let { shallowUnion(this, it) } ?: this }
)
}
else -> {
val thisSplit = this.split(that)
fork(
this to that,
{ thisSplit.left.unionWithImpl(that.left, shallowUnion) },
{ thisSplit.right.unionWithImpl(that.right, shallowUnion) },
{ thisSplit.duplicate?.let { shallowUnion(it, that) } ?: that }
)
}
}
return newThis.with(newLeft, newRight)
}

internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.intersectWith(
that: S?,
shallowIntersect: (S, S) -> S?
): S? =
notForking(this to that) {
intersectWithImpl(that, shallowIntersect)
}

internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.parallelIntersectWith(
that: S?,
parallelThresholdLog2: Int,
shallowIntersect: (S, S) -> S?
): S? =
maybeForking(
this to that,
{
it.first.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1) &&
it.second.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1)
}
) {
intersectWithImpl(that, shallowIntersect)
}

context(ThresholdForker<Pair<S?, S?>>)
private fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.intersectWithImpl(
that: S?,
shallowIntersect: (S, S) -> S?
): S? {
val (newLeft, newRight, newThis) = when {
this == null || that == null -> return null
this.comparePriorityTo(that) >= 0 -> {
val thatSplit = that.split(this)
fork(
this to that,
{ this.left.intersectWithImpl(thatSplit.left, shallowIntersect) },
{ this.right.intersectWithImpl(thatSplit.right, shallowIntersect) },
{ thatSplit.duplicate?.let { shallowIntersect(this, it) } }
)
}
else -> {
val thisSplit = this.split(that)
fork(
this to that,
{ thisSplit.left.intersectWithImpl(that.left, shallowIntersect) },
{ thisSplit.right.intersectWithImpl(that.right, shallowIntersect) },
{ thisSplit.duplicate?.let { shallowIntersect(it, that) } }
)
}
}
return newThis?.with(newLeft, newRight) ?: (newLeft join newRight)
}
6 changes: 6 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap<K
else -> put(key, v)
}

override fun union(m: Map<K, V>, merger: (K, V, V) -> V): TreapMap<K, V> = putAll(m)
override fun parallelUnion(m: Map<K, V>, parallelThresholdLog2: Int, merger: (K, V, V) -> V): TreapMap<K, V> = putAll(m)

override fun intersect(m: Map<K, V>, merger: (K, V, V) -> V): TreapMap<K, V> = this
override fun parallelIntersect(m: Map<K, V>, parallelThresholdLog2: Int, merger: (K, V, V) -> V): TreapMap<K, V> = this

override fun merge(m: Map<K, V>, merger: (K, V?, V?) -> V?): TreapMap<K, V> {
var map: TreapMap<K, V> = this
for ((key, value) in m) {
Expand Down
42 changes: 42 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,48 @@ internal class HashTreapMap<@Treapable K, V>(
}
}

override fun getShallowUnionMerger(
merger: (K, V, V) -> V
): (HashTreapMap<K, V>, HashTreapMap<K, V>) -> HashTreapMap<K, V> = { t1, t2 ->
var newPairs: KeyValuePairList.More<K, V>? = null
t1.forEachPair { (k, v1) ->
val v = if (t2.shallowContainsKey(k)) {
@Suppress("UNCHECKED_CAST")
merger(k, v1, t2.shallowGetValue(k) as V)
} else {
v1
}
newPairs = KeyValuePairList.More(k, v, newPairs)
}
t2.forEachPair { (k, v2) ->
if (!t1.shallowContainsKey(k)) {
newPairs = KeyValuePairList.More(k, v2, newPairs)
}
}
newPairs!!.let { firstPair ->
val newNode = HashTreapMap(firstPair.key, firstPair.value, firstPair.next, t1.left, t1.right)
if (newNode.shallowEquals(t1)) { t1 } else { newNode }
}
}

override fun getShallowIntersectMerger(
merger: (K, V, V) -> V
): (HashTreapMap<K, V>, HashTreapMap<K, V>) -> HashTreapMap<K, V>? = { t1, t2 ->
var newPairs: KeyValuePairList.More<K, V>? = null
t1.forEachPair { (k, v1) ->
if (t2.shallowContainsKey(k)) {
@Suppress("UNCHECKED_CAST")
val v2 = t2.shallowGetValue(k) as V
val v = merger(k, v1, v2)
newPairs = KeyValuePairList.More(k, v, newPairs)
}
}
newPairs?.let { firstPair ->
val newNode = HashTreapMap(firstPair.key, firstPair.value, firstPair.next, t1.left, t1.right)
if (newNode.shallowEquals(t1)) { t1 } else { newNode }
}
}

private inline fun KeyValuePairList<K, V>?.forEachPair(action: (KeyValuePairList<K, V>) -> Unit) {
var current = this
while (current != null) {
Expand Down
14 changes: 14 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ internal class SortedTreapMap<@Treapable K, V>(

override fun arbitraryOrNull(): Map.Entry<K, V>? = MapEntry(key, value)

override fun getShallowUnionMerger(
merger: (K, V, V) -> V
): (SortedTreapMap<K, V>, SortedTreapMap<K, V>) -> SortedTreapMap<K, V> = { t1, t2 ->
val v = merger(t1.key, t1.value, t2.value)
SortedTreapMap(t1.key, v, t1.left, t1.right)
}

override fun getShallowIntersectMerger(
merger: (K, V, V) -> V
): (SortedTreapMap<K, V>, SortedTreapMap<K, V>) -> SortedTreapMap<K, V>? = { t1, t2 ->
val v = merger(t1.key, t1.value, t2.value)
SortedTreapMap(t1.key, v, t1.left, t1.right)
}

override fun getShallowMerger(merger: (K, V?, V?) -> V?): (SortedTreapMap<K, V>?, SortedTreapMap<K, V>?) -> SortedTreapMap<K, V>? = { t1, t2 ->
val k = (t1 ?: t2)!!.key
val v1 = t1?.value
Expand Down
4 changes: 3 additions & 1 deletion collect/src/main/kotlin/com/certora/collect/Treap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ internal fun <@Treapable T, S : Treap<T, S>> Treap<T, S>?.split(key: TreapKey<T>
}
}
}
internal class Split<@Treapable T, S : Treap<T, S>>(var left: S?, var right: S?, var duplicate: S?)
internal class Split<@Treapable T, S : Treap<T, S>>(var left: S?, var right: S?, var duplicate: S?) {
override fun toString(): String = "Split(left=$left, right=$right, duplicate=$duplicate)"
}


/**
Expand Down
52 changes: 52 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/TreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,58 @@ public sealed interface TreapMap<K, V> : PersistentMap<K, V> {
*/
public fun forEachEntry(action: (Map.Entry<K, V>) -> Unit): Unit

/**
Produces a new map containing the keys from this map and another map [m].
If a key is present in just one of the maps, the resulting map will contain the key with the corresponding value
from that map. If a key is present in both maps, [merger] is called with the key, the value from this map, and
the value from [m], in that order, and the returned value will appear in the resulting map.
*/
public fun union(
m: Map<K, V>,
merger: (K, V, V) -> V
): TreapMap<K, V>

/**
Produces a new map containing the keys from this map and another map [m].
If a key is present in just one of the maps, the resulting map will contain the key with the corresponding value
from that map. If a key is present in both maps, [merger] is called with the key, the value from this map, and
the value from [m], in that order, and the returned value will appear in the resulting map.
Merge operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2.
*/
public fun parallelUnion(
m: Map<K, V>,
parallelThresholdLog2: Int = 4,
merger: (K, V, V) -> V
): TreapMap<K, V>

/**
Produces a new map containing the keys that are present in both this map and another map [m].
For each key, the resulting map will contain the key with the value returned by [merger], which is called with
the key, the value from this map, and the value from [m], in that order.
*/
public fun intersect(
m: Map<K, V>,
merger: (K, V, V) -> V
): TreapMap<K, V>

/**
Produces a new map containing the keys that are present in both this map and another map [m].
For each key, the resulting map will contain the key with the value returned by [merger], which is called with
the key, the value from this map, and the value from [m], in that order.
Merge operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2.
*/
public fun parallelIntersect(
m: Map<K, V>,
parallelThresholdLog2: Int = 4,
merger: (K, V, V) -> V
): TreapMap<K, V>

/**
Produces a new [TreapMap] with updated entries, by applying supplied [merger] to each entry of this map and
another map [m].
Expand Down
60 changes: 60 additions & 0 deletions collect/src/test/kotlin/com/certora/collect/TreapMapTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,66 @@ abstract class TreapMapTest {
m1.merge(m2, merger3))
}

@Test
fun union() {
assertEquals(testMapOf(), testMapOf().union(testMapOf()) { _, a, _ -> a })
assertEquals(testMapOf(1 to 2), testMapOf(1 to 2).union(testMapOf()) { _, a, _ -> a })
assertEquals(testMapOf(1 to 2), testMapOf().union(testMapOf(1 to 2)) { _, a, _ -> a })
assertEquals(
testMapOf(1 to 2, 2 to 3, 3 to 4),
testMapOf(1 to 2, 2 to 3).union(testMapOf(2 to 3, 3 to 4)) { _, a, _ -> a }
)

val m1 = testMapOf(2 to 2, 3 to 3)
val m2 = testMapOf(3 to 4)
assertEquals(
mapOf(2 to 2, 3 to 3),
m1.union(m2) { _, a, _ -> a }
)
assertEquals(
mapOf(2 to 2, 3 to 4),
m2.union(m1) { _, a, _ -> a }
)
assertEquals(
mapOf(2 to 2, 3 to 4),
m1.union(m2) { _, _, b -> b }
)
assertEquals(
mapOf(2 to 2, 3 to 3),
m2.union(m1) { _, _, b -> b }
)
}

@Test
fun intersect() {
assertEquals(testMapOf(), testMapOf().intersect(testMapOf()) { _, a, _ -> a })
assertEquals(testMapOf(), testMapOf(1 to 2).intersect(testMapOf()) { _, a, _ -> a })
assertEquals(testMapOf(), testMapOf().intersect(testMapOf(1 to 2)) { _, a, _ -> a })
assertEquals(
testMapOf(2 to 3),
testMapOf(1 to 2, 2 to 3).intersect(testMapOf(2 to 3, 3 to 4)) { _, a, _ -> a }
)

val m1 = testMapOf(2 to 2, 3 to 3)
val m2 = testMapOf(3 to 4)
assertEquals(
mapOf(3 to 3),
m1.intersect(m2) { _, a, _ -> a }
)
assertEquals(
mapOf(3 to 4),
m2.intersect(m1) { _, a, _ -> a }
)
assertEquals(
mapOf(3 to 4),
m1.intersect(m2) { _, _, b -> b }
)
assertEquals(
mapOf(3 to 3),
m2.intersect(m1) { _, _, b -> b }
)
}

@Test
fun zip() {
assertEquals(
Expand Down

0 comments on commit 9a78d9a

Please sign in to comment.