Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TreapMap union and intersection #19

Merged
merged 3 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
Applies a merge function to all entries in this Treap node.
*/
abstract fun getShallowMerger(merger: (K, V?, V?) -> V?): (S?, S?) -> S?
abstract fun getShallowUnionMerger(merger: (K, V, V) -> V): (S, S) -> S
abstract fun getShallowIntersectMerger(merger: (K, V, V) -> V): (S, S) -> S?

private fun containsEntry(entry: Map.Entry<K, V>): Boolean {
val key = entry.key
Expand Down Expand Up @@ -152,6 +154,52 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
override operator fun iterator() = entrySequence().map { it.value }.iterator()
}

override fun union(m: Map<K, V>, merger: (K, V, V) -> V): TreapMap<K, V> =
m.useAsTreap(
{ otherTreap -> self.unionWith(otherTreap, getShallowUnionMerger(merger)) ?: clear() },
{ fallbackUnion(m, merger) }
)

override fun parallelUnion(m: Map<K, V>, parallelThresholdLog2: Int, merger: (K, V, V) -> V): TreapMap<K, V> =
m.useAsTreap(
{ otherTreap -> self.parallelUnionWith(otherTreap, parallelThresholdLog2, getShallowUnionMerger(merger)) ?: clear() },
{ fallbackUnion(m, merger) }
)

private fun fallbackUnion(m: Map<K, V>, merger: (K, V, V) -> V): TreapMap<K, V> {
var newThis = this as TreapMap<K, V>
for ((k, v) in m.entries) {
if (k in this) {
newThis = newThis + (k to merger(k, this[k]!!, v))
} else {
newThis = newThis + (k to v)
}
}
return newThis
}

override fun intersect(m: Map<K, V>, merger: (K, V, V) -> V): TreapMap<K, V> =
m.useAsTreap(
{ otherTreap -> self.intersectWith(otherTreap, getShallowIntersectMerger(merger)) ?: clear() },
{ fallbackIntersect(m, merger) }
)

override fun parallelIntersect(m: Map<K, V>, parallelThresholdLog2: Int, merger: (K, V, V) -> V): TreapMap<K, V> =
m.useAsTreap(
{ otherTreap -> self.parallelIntersectWith(otherTreap, parallelThresholdLog2, getShallowIntersectMerger(merger)) ?: clear() },
{ fallbackIntersect(m, merger) }
)

private fun fallbackIntersect(m: Map<K, V>, merger: (K, V, V) -> V): TreapMap<K, V> {
var newThis = clear()
for ((k, v) in m.entries) {
if (k in this) {
newThis = newThis + (k to merger(k, this[k]!!, v))
}
}
return newThis
}

/**
Merges the entries in `m` with the entries in this AbstractTreapMap, applying the "merger" function to get the
new values for each key.
Expand Down Expand Up @@ -491,3 +539,107 @@ private fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.merge
return newThis?.with(newLeft, newRight) ?: (newLeft join newRight)
}

internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.unionWith(
that: S?,
shallowUnion: (S, S) -> S
): S? =
notForking(this to that) {
unionWithImpl(that, shallowUnion)
}

internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.parallelUnionWith(
that: S?,
parallelThresholdLog2: Int,
shallowUnion: (S, S) -> S
): S? =
maybeForking(
this to that,
{
it.first.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1) &&
it.second.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1)
}
) {
unionWithImpl(that, shallowUnion)
}

context(ThresholdForker<Pair<S?, S?>>)
private fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.unionWithImpl(
that: S?,
shallowUnion: (S, S) -> S
): S? {
val (newLeft, newRight, newThis) = when {
this == null -> return that
that == null -> return this
this.comparePriorityTo(that) >= 0 -> {
val thatSplit = that.split(this)
fork(
this to that,
{ this.left.unionWithImpl(thatSplit.left, shallowUnion) },
{ this.right.unionWithImpl(thatSplit.right, shallowUnion) },
{ thatSplit.duplicate?.let { shallowUnion(this, it) } ?: this }
)
}
else -> {
val thisSplit = this.split(that)
fork(
this to that,
{ thisSplit.left.unionWithImpl(that.left, shallowUnion) },
{ thisSplit.right.unionWithImpl(that.right, shallowUnion) },
{ thisSplit.duplicate?.let { shallowUnion(it, that) } ?: that }
)
}
}
return newThis.with(newLeft, newRight)
}

internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.intersectWith(
that: S?,
shallowIntersect: (S, S) -> S?
): S? =
notForking(this to that) {
intersectWithImpl(that, shallowIntersect)
}

internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.parallelIntersectWith(
that: S?,
parallelThresholdLog2: Int,
shallowIntersect: (S, S) -> S?
): S? =
maybeForking(
this to that,
{
it.first.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1) &&
it.second.isApproximatelySmallerThanLog2(parallelThresholdLog2 - 1)
}
) {
intersectWithImpl(that, shallowIntersect)
}

context(ThresholdForker<Pair<S?, S?>>)
private fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.intersectWithImpl(
that: S?,
shallowIntersect: (S, S) -> S?
): S? {
val (newLeft, newRight, newThis) = when {
this == null || that == null -> return null
this.comparePriorityTo(that) >= 0 -> {
val thatSplit = that.split(this)
fork(
this to that,
{ this.left.intersectWithImpl(thatSplit.left, shallowIntersect) },
{ this.right.intersectWithImpl(thatSplit.right, shallowIntersect) },
{ thatSplit.duplicate?.let { shallowIntersect(this, it) } }
)
}
else -> {
val thisSplit = this.split(that)
fork(
this to that,
{ thisSplit.left.intersectWithImpl(that.left, shallowIntersect) },
{ thisSplit.right.intersectWithImpl(that.right, shallowIntersect) },
{ thisSplit.duplicate?.let { shallowIntersect(it, that) } }
)
}
}
return newThis?.with(newLeft, newRight) ?: (newLeft join newRight)
}
6 changes: 6 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap<K
else -> put(key, v)
}

override fun union(m: Map<K, V>, merger: (K, V, V) -> V): TreapMap<K, V> = putAll(m)
override fun parallelUnion(m: Map<K, V>, parallelThresholdLog2: Int, merger: (K, V, V) -> V): TreapMap<K, V> = putAll(m)

override fun intersect(m: Map<K, V>, merger: (K, V, V) -> V): TreapMap<K, V> = this
override fun parallelIntersect(m: Map<K, V>, parallelThresholdLog2: Int, merger: (K, V, V) -> V): TreapMap<K, V> = this

override fun merge(m: Map<K, V>, merger: (K, V?, V?) -> V?): TreapMap<K, V> {
var map: TreapMap<K, V> = this
for ((key, value) in m) {
Expand Down
42 changes: 42 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,48 @@ internal class HashTreapMap<@Treapable K, V>(
}
}

override fun getShallowUnionMerger(
merger: (K, V, V) -> V
): (HashTreapMap<K, V>, HashTreapMap<K, V>) -> HashTreapMap<K, V> = { t1, t2 ->
var newPairs: KeyValuePairList.More<K, V>? = null
t1.forEachPair { (k, v1) ->
val v = if (t2.shallowContainsKey(k)) {
@Suppress("UNCHECKED_CAST")
merger(k, v1, t2.shallowGetValue(k) as V)
} else {
v1
}
newPairs = KeyValuePairList.More(k, v, newPairs)
}
t2.forEachPair { (k, v2) ->
if (!t1.shallowContainsKey(k)) {
newPairs = KeyValuePairList.More(k, v2, newPairs)
}
}
newPairs!!.let { firstPair ->
val newNode = HashTreapMap(firstPair.key, firstPair.value, firstPair.next, t1.left, t1.right)
if (newNode.shallowEquals(t1)) { t1 } else { newNode }
}
}

override fun getShallowIntersectMerger(
merger: (K, V, V) -> V
): (HashTreapMap<K, V>, HashTreapMap<K, V>) -> HashTreapMap<K, V>? = { t1, t2 ->
var newPairs: KeyValuePairList.More<K, V>? = null
t1.forEachPair { (k, v1) ->
if (t2.shallowContainsKey(k)) {
@Suppress("UNCHECKED_CAST")
val v2 = t2.shallowGetValue(k) as V
val v = merger(k, v1, v2)
newPairs = KeyValuePairList.More(k, v, newPairs)
}
}
newPairs?.let { firstPair ->
val newNode = HashTreapMap(firstPair.key, firstPair.value, firstPair.next, t1.left, t1.right)
if (newNode.shallowEquals(t1)) { t1 } else { newNode }
}
}

private inline fun KeyValuePairList<K, V>?.forEachPair(action: (KeyValuePairList<K, V>) -> Unit) {
var current = this
while (current != null) {
Expand Down
14 changes: 14 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ internal class SortedTreapMap<@Treapable K, V>(

override fun arbitraryOrNull(): Map.Entry<K, V>? = MapEntry(key, value)

override fun getShallowUnionMerger(
merger: (K, V, V) -> V
): (SortedTreapMap<K, V>, SortedTreapMap<K, V>) -> SortedTreapMap<K, V> = { t1, t2 ->
val v = merger(t1.key, t1.value, t2.value)
SortedTreapMap(t1.key, v, t1.left, t1.right)
}

override fun getShallowIntersectMerger(
merger: (K, V, V) -> V
): (SortedTreapMap<K, V>, SortedTreapMap<K, V>) -> SortedTreapMap<K, V>? = { t1, t2 ->
val v = merger(t1.key, t1.value, t2.value)
SortedTreapMap(t1.key, v, t1.left, t1.right)
}

override fun getShallowMerger(merger: (K, V?, V?) -> V?): (SortedTreapMap<K, V>?, SortedTreapMap<K, V>?) -> SortedTreapMap<K, V>? = { t1, t2 ->
val k = (t1 ?: t2)!!.key
val v1 = t1?.value
Expand Down
4 changes: 3 additions & 1 deletion collect/src/main/kotlin/com/certora/collect/Treap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ internal fun <@Treapable T, S : Treap<T, S>> Treap<T, S>?.split(key: TreapKey<T>
}
}
}
internal class Split<@Treapable T, S : Treap<T, S>>(var left: S?, var right: S?, var duplicate: S?)
internal class Split<@Treapable T, S : Treap<T, S>>(var left: S?, var right: S?, var duplicate: S?) {
override fun toString(): String = "Split(left=$left, right=$right, duplicate=$duplicate)"
}


/**
Expand Down
90 changes: 90 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/TreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,92 @@ public sealed interface TreapMap<K, V> : PersistentMap<K, V> {
*/
public fun arbitraryOrNull(): Map.Entry<K, V>?

/**
Calls [action] for each entry in the map.

Traverses the treap without allocating temprarory storage, which may be more efficient than `entries.forEach`.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"temporary"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks

*/
public fun forEachEntry(action: (Map.Entry<K, V>) -> Unit): Unit

/**
Produces a new map containing the keys from this map and another map [m].

If a key is present in just one of the maps, the resulting map will contain the key with the corresponding value
from that map. If a key is present in both maps, [merger] is called with the key, the value from this map, and
the value from [m], in that order, and the returned value will appear in the resulting map.
*/
public fun union(
m: Map<K, V>,
merger: (K, V, V) -> V
): TreapMap<K, V>

/**
Produces a new map containing the keys from this map and another map [m].

If a key is present in just one of the maps, the resulting map will contain the key with the corresponding value
from that map. If a key is present in both maps, [merger] is called with the key, the value from this map, and
the value from [m], in that order, and the returned value will appear in the resulting map.

Merge operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2.
*/
public fun parallelUnion(
m: Map<K, V>,
parallelThresholdLog2: Int = 4,
merger: (K, V, V) -> V
): TreapMap<K, V>

/**
Produces a new map containing the keys that are present in both this map and another map [m].

For each key, the resulting map will contain the key with the value returned by [merger], which is called with
the key, the value from this map, and the value from [m], in that order.
*/
public fun intersect(
m: Map<K, V>,
merger: (K, V, V) -> V
): TreapMap<K, V>

/**
Produces a new map containing the keys that are present in both this map and another map [m].

For each key, the resulting map will contain the key with the value returned by [merger], which is called with
the key, the value from this map, and the value from [m], in that order.

Merge operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2.
*/
public fun parallelIntersect(
m: Map<K, V>,
parallelThresholdLog2: Int = 4,
merger: (K, V, V) -> V
): TreapMap<K, V>

/**
Produces a new [TreapMap] with updated entries, by applying supplied [merger] to each entry of this map and
another map [m].

The [merger] function is called for each key that is present in either map, with the key, the value from this
map, and the value from [m], in that order, as arguments. If the key is not present in one of the maps, the
corresponding [merger] argument will be `null`.

If the [merger] function returns null, the key is not added to the resulting map.
*/
public fun merge(
m: Map<K, V>,
merger: (K, V?, V?) -> V?
): TreapMap<K, V>

/**
Produces a new [TreapMap] with updated entries, by applying supplied [merger] to each entry of this map and
another map [m].

The [merger] function is called for each key that is present in either map, with the key, the value from this
map, and the value from [m], in that order, as arguments. If the key is not present in one of the maps, the
corresponding [merger] argument will be `null`.

If the [merger] function returns null, the key is not added to the resulting map.

Merge operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2.
*/
public fun parallelMerge(
m: Map<K, V>,
parallelThresholdLog2: Int = 4,
Expand Down Expand Up @@ -68,12 +147,23 @@ public sealed interface TreapMap<K, V> : PersistentMap<K, V> {
transform: (K, V) -> R?
): TreapMap<K, R>

/**
Produces a new [TreapMap] with the entry for the specified [key] updated via [merger].

[merger] is called with the current value for the key (or null if the key is absent), and supplied [value]
argument. If the [merger] function returns null, the key will be absent from the resulting map. Otherwise
the resulting map will contain the key with the value returned by the [merger] function.
*/
public fun <U> updateEntry(
key: K,
value: U,
merger: (V?, U) -> V?
): TreapMap<K, V>

/**
Produces a sequence from the entries of this map and another map. For each key, the result is an entry mapping
the key to a pair of values. Each value may be null, if the key is not present in the corresponding map.
*/
public fun zip(
m: Map<out K, V>
): Sequence<Map.Entry<K, Pair<V?, V?>>>
Expand Down
Loading
Loading