diff --git a/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt index 78795ea..5a9e90f 100644 --- a/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt @@ -77,7 +77,8 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT abstract fun shallowGetValue(key: K): V? abstract fun shallowRemoveEntry(key: K, value: V): S? - abstract fun shallowUpdate(entryKey: K, toUpdate: U, merger: (V?, U?) -> V?): S? + abstract fun shallowUpdate(entryKey: K, toUpdate: U, merger: (V?, U) -> V?): S? + abstract fun shallowMapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R /** Applies a merge function to all entries in this Treap node. @@ -212,7 +213,13 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT /** Applies a transform to each entry, producing new values. */ - override fun updateValues(transform: (K, V) -> V?): TreapMap = when { + @Suppress("UNCHECKED_CAST", "Treapability") + override fun updateValues(transform: (K, V) -> R?): TreapMap = + (this as AbstractTreapMap).updateValuesErasedTypes( + transform as (Any?, Any?) -> Any? + ) as TreapMap + + private fun updateValuesErasedTypes(transform: (K, V) -> V?): TreapMap = when { isEmpty() -> self else -> notForking(this) { updateValuesImpl(transform) ?: clear() @@ -227,7 +234,14 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT @param[transform] The transform to apply to each entry. Must be pure and thread-safe. */ - override fun parallelUpdateValues(parallelThresholdLog2: Int, transform: (K, V) -> V?): TreapMap = when { + @Suppress("UNCHECKED_CAST", "Treapability") + override fun parallelUpdateValues(parallelThresholdLog2: Int, transform: (K, V) -> R?): TreapMap = + (this as AbstractTreapMap).parallelUpdateValuesErasedTypes( + parallelThresholdLog2, + transform as (Any?, Any?) -> Any? + ) as TreapMap + + private fun parallelUpdateValuesErasedTypes(parallelThresholdLog2: Int, transform: (K, V) -> V?): TreapMap = when { isEmpty() -> self else -> maybeForking(self, threshold = { it.isApproximatelySmallerThanLog2(parallelThresholdLog2) }) { updateValuesImpl(transform) ?: clear() @@ -280,7 +294,7 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT } ``` */ - override fun updateEntry(key: K, value: U?, merger: (V?, U?) -> V?): TreapMap { + override fun updateEntry(key: K, value: U, merger: (V?, U) -> V?): TreapMap { return self.updateEntry(key.toTreapKey().precompute(), key, value, merger, ::new) ?: clear() } @@ -332,6 +346,26 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT private fun shallowZipThisOnly() = shallowEntrySequence().map { MapEntry(it.key, it.value to null) } private fun shallowZipThatOnly() = shallowEntrySequence().map { MapEntry(it.key, null to it.value) } protected abstract fun shallowZip(that: S): Sequence>> + + override fun mapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R = + notForking(self) { mapReduceImpl(map, reduce) } + + override fun parallelMapReduce(map: (K, V) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R = + maybeForking(self, threshold = { it.isApproximatelySmallerThanLog2(parallelThresholdLog2) }) { + mapReduceImpl(map, reduce) + } + + context(ThresholdForker) + private fun mapReduceImpl(map: (K, V) -> R, reduce: (R, R) -> R): R { + val (left, middle, right) = fork( + self, + { left?.mapReduceImpl(map, reduce) }, + { shallowMapReduce(map, reduce) }, + { right?.mapReduceImpl(map, reduce) } + ) + val leftAndMiddle = left?.let { reduce(it, middle) } ?: middle + return right?.let { reduce(leftAndMiddle, it) } ?: leftAndMiddle + } } /** @@ -357,8 +391,8 @@ internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap> S?.remo internal fun <@Treapable K, V, U, @Treapable S : AbstractTreapMap> S?.updateEntry( thatKey: TreapKey, entryKey: K, - toUpdate: U?, - merger: (V?, U?) -> V?, + toUpdate: U, + merger: (V?, U) -> V?, new: (K, V) -> S ): S? = when { this == null -> { diff --git a/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt index a80e390..52fe41d 100644 --- a/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt @@ -1,5 +1,7 @@ package com.certora.collect +import com.certora.forkjoin.* + /** Base class for TreapSet implementations. Provides the Set operations; derived classes deal with type-specific behavior such as hash collisions. See `Treap` for an overview of all of this. @@ -62,6 +64,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet> abstract fun shallowRemoveAll(predicate: (E) -> Boolean): S? abstract fun shallowContainsAll(elements: S): Boolean abstract fun shallowContainsAny(elements: S): Boolean + abstract fun shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -152,6 +155,26 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet> left === null && right === null -> shallowGetSingleElement() else -> null } + + override fun mapReduce(map: (E) -> R, reduce: (R, R) -> R): R = + notForking(self) { mapReduceImpl(map, reduce) } + + override fun parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R = + maybeForking(self, threshold = { it.isApproximatelySmallerThanLog2(parallelThresholdLog2) }) { + mapReduceImpl(map, reduce) + } + + context(ThresholdForker) + private fun mapReduceImpl(map: (E) -> R, reduce: (R, R) -> R): R { + val (left, middle, right) = fork( + self, + { left?.mapReduceImpl(map, reduce) }, + { shallowMapReduce(map, reduce) }, + { right?.mapReduceImpl(map, reduce) } + ) + val leftAndMiddle = left?.let { reduce(it, middle) } ?: middle + return right?.let { reduce(leftAndMiddle, it) } ?: leftAndMiddle + } } /** diff --git a/collect/src/main/kotlin/com/certora/collect/EmptyTreapList.kt b/collect/src/main/kotlin/com/certora/collect/EmptyTreapList.kt index d686b45..619fa14 100644 --- a/collect/src/main/kotlin/com/certora/collect/EmptyTreapList.kt +++ b/collect/src/main/kotlin/com/certora/collect/EmptyTreapList.kt @@ -58,6 +58,9 @@ internal class EmptyTreapList private constructor() : TreapList, java.io.S else -> throw IndexOutOfBoundsException("Empty list") } + override fun mapReduce(map: (E) -> R, reduce: (R, R) -> R): R? = null + override fun parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R? = null + companion object { private val instance = EmptyTreapList() @Suppress("UNCHECKED_CAST") diff --git a/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt index 3192f0a..5975ae1 100644 --- a/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt @@ -19,10 +19,19 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap = this override fun remove(key: K, value: V): TreapMap = this - override fun updateValues(transform: (K, V) -> V?): TreapMap = this - override fun parallelUpdateValues(parallelThresholdLog2: Int, transform: (K, V) -> V?): TreapMap = this + override fun updateValues( + transform: (K, V) -> R? + ): TreapMap = treapMapOf() - override fun updateEntry(key: K, value: U?, merger: (V?, U?) -> V?): TreapMap = + override fun parallelUpdateValues( + parallelThresholdLog2: Int, + transform: (K, V) -> R? + ): TreapMap = treapMapOf() + + override fun mapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R? = null + override fun parallelMapReduce(map: (K, V) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R? = null + + override fun updateEntry(key: K, value: U, merger: (V?, U) -> V?): TreapMap = when (val v = merger(null, value)) { null -> this else -> put(key, v) @@ -53,7 +62,7 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap = when (key) { is PrefersHashTreap -> HashTreapMap(key, value) - is Comparable<*> -> + is Comparable<*> -> SortedTreapMap>, V>(key as Comparable>, value) as TreapMap else -> HashTreapMap(key, value) } @@ -71,4 +80,4 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap invoke(): EmptyTreapMap = instance as EmptyTreapMap } -} \ No newline at end of file +} diff --git a/collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt index a8a2082..3628343 100644 --- a/collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt @@ -24,6 +24,8 @@ internal class EmptyTreapSet<@Treapable E> private constructor() : TreapSet, override fun retainAll(elements: Collection): TreapSet = this override fun single(): E = throw NoSuchElementException("Empty set.") override fun singleOrNull(): E? = null + override fun mapReduce(map: (E) -> R, reduce: (R, R) -> R): R? = null + override fun parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R? = null @Suppress("Treapability", "UNCHECKED_CAST") override fun add(element: E): TreapSet = when (element) { diff --git a/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt index 4490ec1..5666af8 100644 --- a/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt @@ -165,7 +165,7 @@ internal class HashTreapMap<@Treapable K, V>( } } - override fun shallowUpdate(entryKey: K, toUpdate: U, merger: (V?, U?) -> V?): HashTreapMap? { + override fun shallowUpdate(entryKey: K, toUpdate: U, merger: (V?, U) -> V?): HashTreapMap? { return when (this.key) { entryKey -> { val newValue = merger(this.value, toUpdate) @@ -280,6 +280,15 @@ internal class HashTreapMap<@Treapable K, V>( forEachPair { (k, v) -> h += AbstractMapEntry.hashCode(k, v) } return h } + + override fun shallowMapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R { + var result: R? = null + forEachPair { + val mapped = map(it.key, it.value) + result = result?.let { result -> reduce(result, mapped) } ?: mapped + } + return result!! + } } internal interface KeyValuePairList { diff --git a/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt index ca06e08..8346603 100644 --- a/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt @@ -229,6 +229,15 @@ internal class HashTreapSet<@Treapable E>( }.iterator() override fun shallowGetSingleElement(): E? = element.takeIf { next == null } + + override fun shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R { + var result: R? = null + forEachNodeElement { + val mapped = map(it) + result = result?.let { result -> reduce(result, mapped) } ?: mapped + } + return result!! + } } internal interface ElementList { diff --git a/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt index b17002c..e447172 100644 --- a/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt @@ -74,7 +74,7 @@ internal class SortedTreapMap<@Treapable K : Comparable, V>( } } - override fun shallowUpdate(entryKey: K, toUpdate: U, merger: (V?, U?) -> V?): SortedTreapMap? { + override fun shallowUpdate(entryKey: K, toUpdate: U, merger: (V?, U) -> V?): SortedTreapMap? { val newValue = merger(value, toUpdate) return when { newValue == null -> null @@ -119,4 +119,6 @@ internal class SortedTreapMap<@Treapable K : Comparable, V>( fun firstEntry(): Map.Entry? = left?.firstEntry() ?: this.asEntry() fun lastEntry(): Map.Entry? = right?.lastEntry() ?: this.asEntry() + + override fun shallowMapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R = map(key, value) } diff --git a/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt index 9a13093..40f31c9 100644 --- a/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt @@ -49,4 +49,5 @@ internal class SortedTreapSet<@Treapable E : Comparable>( override fun shallowComputeHashCode(): Int = treapKey.hashCode() override fun shallowGetSingleElement(): E = treapKey override fun shallowForEach(action: (element: E) -> Unit): Unit { action(treapKey) } + override fun shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R = map(treapKey) } diff --git a/collect/src/main/kotlin/com/certora/collect/Treap.kt b/collect/src/main/kotlin/com/certora/collect/Treap.kt index fd5fb08..6580046 100644 --- a/collect/src/main/kotlin/com/certora/collect/Treap.kt +++ b/collect/src/main/kotlin/com/certora/collect/Treap.kt @@ -265,7 +265,8 @@ internal fun <@Treapable T, S : Treap> S?.computeHashCode(): Int = when { along a single path, under the assumption that the tree is balanced. */ internal tailrec fun <@Treapable T, S : Treap> S?.isApproximatelySmallerThanLog2(sizeLog2: Int): Boolean = when { + sizeLog2 < 0 -> throw IllegalArgumentException("sizeLog2 must be positive") this == null -> true - sizeLog2 <= 0 -> false + sizeLog2 == 0 -> false else -> this.left.isApproximatelySmallerThanLog2(sizeLog2 - 1) } diff --git a/collect/src/main/kotlin/com/certora/collect/TreapList.kt b/collect/src/main/kotlin/com/certora/collect/TreapList.kt index 38d235d..fd56f87 100644 --- a/collect/src/main/kotlin/com/certora/collect/TreapList.kt +++ b/collect/src/main/kotlin/com/certora/collect/TreapList.kt @@ -37,6 +37,9 @@ public sealed interface TreapList : PersistentList { public fun updateElements(transform: (E) -> E?): TreapList public fun updateElementsIndexed(transform: (Int, E) -> E?): TreapList + public fun mapReduce(map: (E) -> R, reduce: (R, R) -> R): R? + public fun parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int = 5): R? + /** A [PersistentList.Builder] that produces a [TreapList]. */ diff --git a/collect/src/main/kotlin/com/certora/collect/TreapListNode.kt b/collect/src/main/kotlin/com/certora/collect/TreapListNode.kt index 8f6c01f..163b1d9 100644 --- a/collect/src/main/kotlin/com/certora/collect/TreapListNode.kt +++ b/collect/src/main/kotlin/com/certora/collect/TreapListNode.kt @@ -1,5 +1,6 @@ package com.certora.collect +import com.certora.forkjoin.* import kotlin.random.Random import java.lang.Math.addExact @@ -364,6 +365,27 @@ internal class TreapListNode private constructor( right?.forEachNodeIndexed(right.rightIndex(thisIndex), action) } + + override fun mapReduce(map: (E) -> R, reduce: (R, R) -> R): R = + notForking(this) { mapReduceImpl(map, reduce) } + + override fun parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R = + maybeForking(this, threshold = { it.isApproximatelySmallerThanLog2(parallelThresholdLog2) }) { + mapReduceImpl(map, reduce) + } + + context(ThresholdForker>) + private fun mapReduceImpl(map: (E) -> R, reduce: (R, R) -> R): R { + val (left, middle, right) = fork( + this, + { left?.mapReduceImpl(map, reduce) }, + { map(elem) }, + { right?.mapReduceImpl(map, reduce) } + ) + val leftAndMiddle = left?.let { reduce(it, middle) } ?: middle + return right?.let { reduce(leftAndMiddle, it) } ?: leftAndMiddle + } + companion object { private infix fun TreapListNode?.append(that: TreapListNode?): TreapListNode? = when { this == null -> that @@ -430,5 +452,12 @@ internal class TreapListNode private constructor( // Build the whole list return buildLowerPri(Int.MAX_VALUE, elems.next(), Random.Default.nextInt()).node } + + internal tailrec fun TreapListNode?.isApproximatelySmallerThanLog2(sizeLog2: Int): Boolean = when { + sizeLog2 < 0 -> throw IllegalArgumentException("sizeLog2 must be positive") + this == null -> true + sizeLog2 == 0 -> false + else -> this.left.isApproximatelySmallerThanLog2(sizeLog2 - 1) + } } } diff --git a/collect/src/main/kotlin/com/certora/collect/TreapMap.kt b/collect/src/main/kotlin/com/certora/collect/TreapMap.kt index ef538cf..c9adc55 100644 --- a/collect/src/main/kotlin/com/certora/collect/TreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/TreapMap.kt @@ -34,24 +34,56 @@ public sealed interface TreapMap : PersistentMap { merger: (K, V?, V?) -> V? ): TreapMap - public fun updateValues( - transform: (K, V) -> V? - ): TreapMap + /** + Produces a new [TreapMap] with updated entries, by applying the supplied [transform]. Removes entries for which + [transform] returns null. + + Note that even a seemingly non-mutating transform may result in a different map, if the map contains null values: + + ``` + val map = treapMapOf("a" to null).updateValues { _, v -> v } // yields an empty map + ``` + */ + public fun updateValues( + transform: (K, V) -> R? + ): TreapMap + + /** + Produces a new [TreapMap] with updated entries, by applying the supplied [transform]. Removes entries for which + [transform] returns null. + + Operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2. - public fun parallelUpdateValues( + See additional nodes on [updateValues]. + */ + public fun parallelUpdateValues( parallelThresholdLog2: Int = 5, - transform: (K, V) -> V? - ): TreapMap + transform: (K, V) -> R? + ): TreapMap public fun updateEntry( key: K, - value: U?, - merger: (V?, U?) -> V? + value: U, + merger: (V?, U) -> V? ): TreapMap public fun zip( m: Map ): Sequence>> + + /** + Applies the [map] function to each entry, then applies [reduce] to the results, in a depth-first traversal of + the underlying tree. Returns null if the map is empty. + */ + public fun mapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R? + + /** + Applies the [map] function to each entry, then applies [reduce] to the results, in a depth-first traversal of + the underlying tree. Returns null if the map is empty. + + Operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2. + */ + public fun parallelMapReduce(map: (K, V) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int = 5): R? } public fun <@Treapable K, V> treapMapOf(): TreapMap = EmptyTreapMap() diff --git a/collect/src/main/kotlin/com/certora/collect/TreapSet.kt b/collect/src/main/kotlin/com/certora/collect/TreapSet.kt index 81994c6..14e719a 100644 --- a/collect/src/main/kotlin/com/certora/collect/TreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/TreapSet.kt @@ -26,8 +26,9 @@ public sealed interface TreapSet : PersistentSet { override fun builder(): Builder<@UnsafeVariance T> = TreapSetBuilder(this) /** - Checks if this set contains any of the given [elements]. This is equivalent to, but more efficient than, - `this.intersect(elements).isNotEmpty()`. + Checks if this set contains any of the given [elements]. + + This is equivalent to, but more efficient than, `this.intersect(elements).isNotEmpty()`. */ public fun containsAny(elements: Iterable<@UnsafeVariance T>): Boolean @@ -43,15 +44,35 @@ public sealed interface TreapSet : PersistentSet { /** If this set contains an element that compares equal to the specified [element], returns that element instance. + This is useful for implementing intern tables, for example. */ public fun findEqual(element: @UnsafeVariance T): T? /** - Calls [action] for each element in the set. This traverses the treap without allocating temporary storage, - which may be more efficient than [forEach]. + Calls [action] for each element in the set. + + This traverses the treap without allocating temporary storage, which may be more efficient than [forEach]. */ public fun forEachElement(action: (element: T) -> Unit): Unit + + /** + Calls [map] for each element in the set, and then reduces the results with [reduce]. + + This traverses the treap without allocating temporary storage, which may be more efficient than using the [map] + and [reduce] functions. + */ + public fun mapReduce(map: (T) -> R, reduce: (R, R) -> R): R? + + /** + Calls [map] for each element in the set, and then reduces the results with [reduce]. + + Operations are performed in parallel for sets larger than (approximately) 2^parallelThresholdLog2. + + This traverses the treap without allocating temporary storage, which may be more efficient than using the [map] + and [reduce] functions. + */ + public fun parallelMapReduce(map: (T) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int = 5): R? } /**