Skip to content

Commit

Permalink
Some new performance features (#14)
Browse files Browse the repository at this point in the history
While doing some Prover performance work, I found the following to be
useful:

- Add `mapReduce` and `parallelMapReduce` methods on `TreapSet,`
`TreapMap`, and `TreapList`. These do what you think they do, and are
useful for the obvious reasons.

- Allow `TreapMap.updateValues` (and `parallelUpdateValues`) to change
the type of the values. This makes a straightforward mapping of values
to different types an O(N) operation instead of O(N log N).

I also simplified the `TreapMap.updateEntry` signature a bit (it had
some extraneous nullability annotations).
  • Loading branch information
ericeil authored May 21, 2024
1 parent 04fb5d3 commit 840d5eb
Show file tree
Hide file tree
Showing 14 changed files with 204 additions and 26 deletions.
46 changes: 40 additions & 6 deletions collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
abstract fun shallowGetValue(key: K): V?

abstract fun shallowRemoveEntry(key: K, value: V): S?
abstract fun <U> shallowUpdate(entryKey: K, toUpdate: U, merger: (V?, U?) -> V?): S?
abstract fun <U> shallowUpdate(entryKey: K, toUpdate: U, merger: (V?, U) -> V?): S?
abstract fun <R : Any> shallowMapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R

/**
Applies a merge function to all entries in this Treap node.
Expand Down Expand Up @@ -212,7 +213,13 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
/**
Applies a transform to each entry, producing new values.
*/
override fun updateValues(transform: (K, V) -> V?): TreapMap<K, V> = when {
@Suppress("UNCHECKED_CAST", "Treapability")
override fun <R : Any> updateValues(transform: (K, V) -> R?): TreapMap<K, R> =
(this as AbstractTreapMap<Any?, Any?, *>).updateValuesErasedTypes(
transform as (Any?, Any?) -> Any?
) as TreapMap<K, R>

private fun updateValuesErasedTypes(transform: (K, V) -> V?): TreapMap<K, V> = when {
isEmpty() -> self
else -> notForking(this) {
updateValuesImpl(transform) ?: clear()
Expand All @@ -227,7 +234,14 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
@param[transform] The transform to apply to each entry. Must be pure and thread-safe.
*/
override fun parallelUpdateValues(parallelThresholdLog2: Int, transform: (K, V) -> V?): TreapMap<K, V> = when {
@Suppress("UNCHECKED_CAST", "Treapability")
override fun <R : Any> parallelUpdateValues(parallelThresholdLog2: Int, transform: (K, V) -> R?): TreapMap<K, R> =
(this as AbstractTreapMap<Any?, Any?, *>).parallelUpdateValuesErasedTypes(
parallelThresholdLog2,
transform as (Any?, Any?) -> Any?
) as TreapMap<K, R>

private fun parallelUpdateValuesErasedTypes(parallelThresholdLog2: Int, transform: (K, V) -> V?): TreapMap<K, V> = when {
isEmpty() -> self
else -> maybeForking(self, threshold = { it.isApproximatelySmallerThanLog2(parallelThresholdLog2) }) {
updateValuesImpl(transform) ?: clear()
Expand Down Expand Up @@ -280,7 +294,7 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
}
```
*/
override fun <U> updateEntry(key: K, value: U?, merger: (V?, U?) -> V?): TreapMap<K, V> {
override fun <U> updateEntry(key: K, value: U, merger: (V?, U) -> V?): TreapMap<K, V> {
return self.updateEntry(key.toTreapKey().precompute(), key, value, merger, ::new) ?: clear()
}

Expand Down Expand Up @@ -332,6 +346,26 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
private fun shallowZipThisOnly() = shallowEntrySequence().map { MapEntry(it.key, it.value to null) }
private fun shallowZipThatOnly() = shallowEntrySequence().map { MapEntry(it.key, null to it.value) }
protected abstract fun shallowZip(that: S): Sequence<Map.Entry<K, Pair<V?, V?>>>

override fun <R : Any> mapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R =
notForking(self) { mapReduceImpl(map, reduce) }

override fun <R : Any> parallelMapReduce(map: (K, V) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R =
maybeForking(self, threshold = { it.isApproximatelySmallerThanLog2(parallelThresholdLog2) }) {
mapReduceImpl(map, reduce)
}

context(ThresholdForker<S>)
private fun <R : Any> mapReduceImpl(map: (K, V) -> R, reduce: (R, R) -> R): R {
val (left, middle, right) = fork(
self,
{ left?.mapReduceImpl(map, reduce) },
{ shallowMapReduce(map, reduce) },
{ right?.mapReduceImpl(map, reduce) }
)
val leftAndMiddle = left?.let { reduce(it, middle) } ?: middle
return right?.let { reduce(leftAndMiddle, it) } ?: leftAndMiddle
}
}

/**
Expand All @@ -357,8 +391,8 @@ internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.remo
internal fun <@Treapable K, V, U, @Treapable S : AbstractTreapMap<K, V, S>> S?.updateEntry(
thatKey: TreapKey<K>,
entryKey: K,
toUpdate: U?,
merger: (V?, U?) -> V?,
toUpdate: U,
merger: (V?, U) -> V?,
new: (K, V) -> S
): S? = when {
this == null -> {
Expand Down
23 changes: 23 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.certora.collect

import com.certora.forkjoin.*

/**
Base class for TreapSet implementations. Provides the Set operations; derived classes deal with type-specific
behavior such as hash collisions. See `Treap` for an overview of all of this.
Expand Down Expand Up @@ -62,6 +64,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>
abstract fun shallowRemoveAll(predicate: (E) -> Boolean): S?
abstract fun shallowContainsAll(elements: S): Boolean
abstract fun shallowContainsAny(elements: S): Boolean
abstract fun <R : Any> shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R


////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -152,6 +155,26 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>
left === null && right === null -> shallowGetSingleElement()
else -> null
}

override fun <R : Any> mapReduce(map: (E) -> R, reduce: (R, R) -> R): R =
notForking(self) { mapReduceImpl(map, reduce) }

override fun <R : Any> parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R =
maybeForking(self, threshold = { it.isApproximatelySmallerThanLog2(parallelThresholdLog2) }) {
mapReduceImpl(map, reduce)
}

context(ThresholdForker<S>)
private fun <R : Any> mapReduceImpl(map: (E) -> R, reduce: (R, R) -> R): R {
val (left, middle, right) = fork(
self,
{ left?.mapReduceImpl(map, reduce) },
{ shallowMapReduce(map, reduce) },
{ right?.mapReduceImpl(map, reduce) }
)
val leftAndMiddle = left?.let { reduce(it, middle) } ?: middle
return right?.let { reduce(leftAndMiddle, it) } ?: leftAndMiddle
}
}

/**
Expand Down
3 changes: 3 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/EmptyTreapList.kt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ internal class EmptyTreapList<E> private constructor() : TreapList<E>, java.io.S
else -> throw IndexOutOfBoundsException("Empty list")
}

override fun <R : Any> mapReduce(map: (E) -> R, reduce: (R, R) -> R): R? = null
override fun <R : Any> parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R? = null

companion object {
private val instance = EmptyTreapList<Nothing>()
@Suppress("UNCHECKED_CAST")
Expand Down
19 changes: 14 additions & 5 deletions collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,19 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap<K
override fun remove(key: K): TreapMap<K, V> = this
override fun remove(key: K, value: V): TreapMap<K, V> = this

override fun updateValues(transform: (K, V) -> V?): TreapMap<K, V> = this
override fun parallelUpdateValues(parallelThresholdLog2: Int, transform: (K, V) -> V?): TreapMap<K, V> = this
override fun <R : Any> updateValues(
transform: (K, V) -> R?
): TreapMap<K, R> = treapMapOf()

override fun <U> updateEntry(key: K, value: U?, merger: (V?, U?) -> V?): TreapMap<K, V> =
override fun <R : Any> parallelUpdateValues(
parallelThresholdLog2: Int,
transform: (K, V) -> R?
): TreapMap<K, R> = treapMapOf()

override fun <R : Any> mapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R? = null
override fun <R : Any> parallelMapReduce(map: (K, V) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R? = null

override fun <U> updateEntry(key: K, value: U, merger: (V?, U) -> V?): TreapMap<K, V> =
when (val v = merger(null, value)) {
null -> this
else -> put(key, v)
Expand Down Expand Up @@ -53,7 +62,7 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap<K
@Suppress("Treapability", "UNCHECKED_CAST")
override fun put(key: K, value: V): TreapMap<K, V> = when (key) {
is PrefersHashTreap -> HashTreapMap(key, value)
is Comparable<*> ->
is Comparable<*> ->
SortedTreapMap<Comparable<Comparable<*>>, V>(key as Comparable<Comparable<*>>, value) as TreapMap<K, V>
else -> HashTreapMap(key, value)
}
Expand All @@ -71,4 +80,4 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap<K
@Suppress("UNCHECKED_CAST")
operator fun <@Treapable K, V> invoke(): EmptyTreapMap<K, V> = instance as EmptyTreapMap<K, V>
}
}
}
2 changes: 2 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ internal class EmptyTreapSet<@Treapable E> private constructor() : TreapSet<E>,
override fun retainAll(elements: Collection<E>): TreapSet<E> = this
override fun single(): E = throw NoSuchElementException("Empty set.")
override fun singleOrNull(): E? = null
override fun <R : Any> mapReduce(map: (E) -> R, reduce: (R, R) -> R): R? = null
override fun <R : Any> parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R? = null

@Suppress("Treapability", "UNCHECKED_CAST")
override fun add(element: E): TreapSet<E> = when (element) {
Expand Down
11 changes: 10 additions & 1 deletion collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ internal class HashTreapMap<@Treapable K, V>(
}
}

override fun <U> shallowUpdate(entryKey: K, toUpdate: U, merger: (V?, U?) -> V?): HashTreapMap<K, V>? {
override fun <U> shallowUpdate(entryKey: K, toUpdate: U, merger: (V?, U) -> V?): HashTreapMap<K, V>? {
return when (this.key) {
entryKey -> {
val newValue = merger(this.value, toUpdate)
Expand Down Expand Up @@ -280,6 +280,15 @@ internal class HashTreapMap<@Treapable K, V>(
forEachPair { (k, v) -> h += AbstractMapEntry.hashCode(k, v) }
return h
}

override fun <R : Any> shallowMapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R {
var result: R? = null
forEachPair {
val mapped = map(it.key, it.value)
result = result?.let { result -> reduce(result, mapped) } ?: mapped
}
return result!!
}
}

internal interface KeyValuePairList<K, V> {
Expand Down
9 changes: 9 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,15 @@ internal class HashTreapSet<@Treapable E>(
}.iterator()

override fun shallowGetSingleElement(): E? = element.takeIf { next == null }

override fun <R : Any> shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R {
var result: R? = null
forEachNodeElement {
val mapped = map(it)
result = result?.let { result -> reduce(result, mapped) } ?: mapped
}
return result!!
}
}

internal interface ElementList<E> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ internal class SortedTreapMap<@Treapable K : Comparable<K>, V>(
}
}

override fun <U> shallowUpdate(entryKey: K, toUpdate: U, merger: (V?, U?) -> V?): SortedTreapMap<K, V>? {
override fun <U> shallowUpdate(entryKey: K, toUpdate: U, merger: (V?, U) -> V?): SortedTreapMap<K, V>? {
val newValue = merger(value, toUpdate)
return when {
newValue == null -> null
Expand Down Expand Up @@ -119,4 +119,6 @@ internal class SortedTreapMap<@Treapable K : Comparable<K>, V>(

fun firstEntry(): Map.Entry<K, V>? = left?.firstEntry() ?: this.asEntry()
fun lastEntry(): Map.Entry<K, V>? = right?.lastEntry() ?: this.asEntry()

override fun <R : Any> shallowMapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R = map(key, value)
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,5 @@ internal class SortedTreapSet<@Treapable E : Comparable<E>>(
override fun shallowComputeHashCode(): Int = treapKey.hashCode()
override fun shallowGetSingleElement(): E = treapKey
override fun shallowForEach(action: (element: E) -> Unit): Unit { action(treapKey) }
override fun <R : Any> shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R = map(treapKey)
}
3 changes: 2 additions & 1 deletion collect/src/main/kotlin/com/certora/collect/Treap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,8 @@ internal fun <@Treapable T, S : Treap<T, S>> S?.computeHashCode(): Int = when {
along a single path, under the assumption that the tree is balanced.
*/
internal tailrec fun <@Treapable T, S : Treap<T, S>> S?.isApproximatelySmallerThanLog2(sizeLog2: Int): Boolean = when {
sizeLog2 < 0 -> throw IllegalArgumentException("sizeLog2 must be positive")
this == null -> true
sizeLog2 <= 0 -> false
sizeLog2 == 0 -> false
else -> this.left.isApproximatelySmallerThanLog2(sizeLog2 - 1)
}
3 changes: 3 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/TreapList.kt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ public sealed interface TreapList<E> : PersistentList<E> {
public fun updateElements(transform: (E) -> E?): TreapList<E>
public fun updateElementsIndexed(transform: (Int, E) -> E?): TreapList<E>

public fun <R : Any> mapReduce(map: (E) -> R, reduce: (R, R) -> R): R?
public fun <R : Any> parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int = 5): R?

/**
A [PersistentList.Builder] that produces a [TreapList].
*/
Expand Down
29 changes: 29 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/TreapListNode.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.certora.collect

import com.certora.forkjoin.*
import kotlin.random.Random
import java.lang.Math.addExact

Expand Down Expand Up @@ -364,6 +365,27 @@ internal class TreapListNode<E> private constructor(
right?.forEachNodeIndexed(right.rightIndex(thisIndex), action)
}


override fun <R : Any> mapReduce(map: (E) -> R, reduce: (R, R) -> R): R =
notForking(this) { mapReduceImpl(map, reduce) }

override fun <R : Any> parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R =
maybeForking(this, threshold = { it.isApproximatelySmallerThanLog2(parallelThresholdLog2) }) {
mapReduceImpl(map, reduce)
}

context(ThresholdForker<TreapListNode<E>>)
private fun <R : Any> mapReduceImpl(map: (E) -> R, reduce: (R, R) -> R): R {
val (left, middle, right) = fork(
this,
{ left?.mapReduceImpl(map, reduce) },
{ map(elem) },
{ right?.mapReduceImpl(map, reduce) }
)
val leftAndMiddle = left?.let { reduce(it, middle) } ?: middle
return right?.let { reduce(leftAndMiddle, it) } ?: leftAndMiddle
}

companion object {
private infix fun <E> TreapListNode<E>?.append(that: TreapListNode<E>?): TreapListNode<E>? = when {
this == null -> that
Expand Down Expand Up @@ -430,5 +452,12 @@ internal class TreapListNode<E> private constructor(
// Build the whole list
return buildLowerPri(Int.MAX_VALUE, elems.next(), Random.Default.nextInt()).node
}

internal tailrec fun <E> TreapListNode<E>?.isApproximatelySmallerThanLog2(sizeLog2: Int): Boolean = when {
sizeLog2 < 0 -> throw IllegalArgumentException("sizeLog2 must be positive")
this == null -> true
sizeLog2 == 0 -> false
else -> this.left.isApproximatelySmallerThanLog2(sizeLog2 - 1)
}
}
}
48 changes: 40 additions & 8 deletions collect/src/main/kotlin/com/certora/collect/TreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,56 @@ public sealed interface TreapMap<K, V> : PersistentMap<K, V> {
merger: (K, V?, V?) -> V?
): TreapMap<K, V>

public fun updateValues(
transform: (K, V) -> V?
): TreapMap<K, V>
/**
Produces a new [TreapMap] with updated entries, by applying the supplied [transform]. Removes entries for which
[transform] returns null.
Note that even a seemingly non-mutating transform may result in a different map, if the map contains null values:
```
val map = treapMapOf("a" to null).updateValues { _, v -> v } // yields an empty map
```
*/
public fun <R : Any> updateValues(
transform: (K, V) -> R?
): TreapMap<K, R>

/**
Produces a new [TreapMap] with updated entries, by applying the supplied [transform]. Removes entries for which
[transform] returns null.
Operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2.
public fun parallelUpdateValues(
See additional nodes on [updateValues].
*/
public fun <R : Any> parallelUpdateValues(
parallelThresholdLog2: Int = 5,
transform: (K, V) -> V?
): TreapMap<K, V>
transform: (K, V) -> R?
): TreapMap<K, R>

public fun <U> updateEntry(
key: K,
value: U?,
merger: (V?, U?) -> V?
value: U,
merger: (V?, U) -> V?
): TreapMap<K, V>

public fun zip(
m: Map<out K, V>
): Sequence<Map.Entry<K, Pair<V?, V?>>>

/**
Applies the [map] function to each entry, then applies [reduce] to the results, in a depth-first traversal of
the underlying tree. Returns null if the map is empty.
*/
public fun <R : Any> mapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R?

/**
Applies the [map] function to each entry, then applies [reduce] to the results, in a depth-first traversal of
the underlying tree. Returns null if the map is empty.
Operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2.
*/
public fun <R : Any> parallelMapReduce(map: (K, V) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int = 5): R?
}

public fun <@Treapable K, V> treapMapOf(): TreapMap<K, V> = EmptyTreapMap<K, V>()
Expand Down
Loading

0 comments on commit 840d5eb

Please sign in to comment.