Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some new performance features #14

Merged
merged 3 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
abstract fun shallowGetValue(key: K): V?

abstract fun shallowRemoveEntry(key: K, value: V): S?
abstract fun <U> shallowUpdate(entryKey: K, toUpdate: U, merger: (V?, U?) -> V?): S?
abstract fun <U> shallowUpdate(entryKey: K, toUpdate: U, merger: (V?, U) -> V?): S?
abstract fun <R : Any> shallowMapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R

/**
Applies a merge function to all entries in this Treap node.
Expand Down Expand Up @@ -212,7 +213,13 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
/**
Applies a transform to each entry, producing new values.
*/
override fun updateValues(transform: (K, V) -> V?): TreapMap<K, V> = when {
@Suppress("UNCHECKED_CAST", "Treapability")
override fun <R : Any> updateValues(transform: (K, V) -> R?): TreapMap<K, R> =
(this as AbstractTreapMap<Any?, Any?, *>).updateValuesErasedTypes(
transform as (Any?, Any?) -> Any?
) as TreapMap<K, R>

private fun updateValuesErasedTypes(transform: (K, V) -> V?): TreapMap<K, V> = when {
isEmpty() -> self
else -> notForking(this) {
updateValuesImpl(transform) ?: clear()
Expand All @@ -227,7 +234,14 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT

@param[transform] The transform to apply to each entry. Must be pure and thread-safe.
*/
override fun parallelUpdateValues(parallelThresholdLog2: Int, transform: (K, V) -> V?): TreapMap<K, V> = when {
@Suppress("UNCHECKED_CAST", "Treapability")
override fun <R : Any> parallelUpdateValues(parallelThresholdLog2: Int, transform: (K, V) -> R?): TreapMap<K, R> =
(this as AbstractTreapMap<Any?, Any?, *>).parallelUpdateValuesErasedTypes(
parallelThresholdLog2,
transform as (Any?, Any?) -> Any?
) as TreapMap<K, R>

private fun parallelUpdateValuesErasedTypes(parallelThresholdLog2: Int, transform: (K, V) -> V?): TreapMap<K, V> = when {
isEmpty() -> self
else -> maybeForking(self, threshold = { it.isApproximatelySmallerThanLog2(parallelThresholdLog2) }) {
updateValuesImpl(transform) ?: clear()
Expand Down Expand Up @@ -280,7 +294,7 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
}
```
*/
override fun <U> updateEntry(key: K, value: U?, merger: (V?, U?) -> V?): TreapMap<K, V> {
override fun <U> updateEntry(key: K, value: U, merger: (V?, U) -> V?): TreapMap<K, V> {
return self.updateEntry(key.toTreapKey().precompute(), key, value, merger, ::new) ?: clear()
}

Expand Down Expand Up @@ -332,6 +346,26 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
private fun shallowZipThisOnly() = shallowEntrySequence().map { MapEntry(it.key, it.value to null) }
private fun shallowZipThatOnly() = shallowEntrySequence().map { MapEntry(it.key, null to it.value) }
protected abstract fun shallowZip(that: S): Sequence<Map.Entry<K, Pair<V?, V?>>>

override fun <R : Any> mapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R =
notForking(self) { mapReduceImpl(map, reduce) }

override fun <R : Any> parallelMapReduce(map: (K, V) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R =
maybeForking(self, threshold = { it.isApproximatelySmallerThanLog2(parallelThresholdLog2) }) {
mapReduceImpl(map, reduce)
}

context(ThresholdForker<S>)
private fun <R : Any> mapReduceImpl(map: (K, V) -> R, reduce: (R, R) -> R): R {
val (left, middle, right) = fork(
self,
{ left?.mapReduceImpl(map, reduce) },
{ shallowMapReduce(map, reduce) },
{ right?.mapReduceImpl(map, reduce) }
)
val leftAndMiddle = left?.let { reduce(it, middle) } ?: middle
return right?.let { reduce(leftAndMiddle, it) } ?: leftAndMiddle
}
Comment on lines +359 to +368
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a compelling reason to not just go ahead and make this an erased types implementation too? It seems like we always go back and do this for performance anyway so...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Static typing is good. :)

}

/**
Expand All @@ -357,8 +391,8 @@ internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.remo
internal fun <@Treapable K, V, U, @Treapable S : AbstractTreapMap<K, V, S>> S?.updateEntry(
thatKey: TreapKey<K>,
entryKey: K,
toUpdate: U?,
merger: (V?, U?) -> V?,
toUpdate: U,
merger: (V?, U) -> V?,
new: (K, V) -> S
): S? = when {
this == null -> {
Expand Down
23 changes: 23 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.certora.collect

import com.certora.forkjoin.*

/**
Base class for TreapSet implementations. Provides the Set operations; derived classes deal with type-specific
behavior such as hash collisions. See `Treap` for an overview of all of this.
Expand Down Expand Up @@ -62,6 +64,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>
abstract fun shallowRemoveAll(predicate: (E) -> Boolean): S?
abstract fun shallowContainsAll(elements: S): Boolean
abstract fun shallowContainsAny(elements: S): Boolean
abstract fun <R : Any> shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R


////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -152,6 +155,26 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>
left === null && right === null -> shallowGetSingleElement()
else -> null
}

override fun <R : Any> mapReduce(map: (E) -> R, reduce: (R, R) -> R): R =
notForking(self) { mapReduceImpl(map, reduce) }

override fun <R : Any> parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R =
maybeForking(self, threshold = { it.isApproximatelySmallerThanLog2(parallelThresholdLog2) }) {
mapReduceImpl(map, reduce)
}

context(ThresholdForker<S>)
private fun <R : Any> mapReduceImpl(map: (E) -> R, reduce: (R, R) -> R): R {
val (left, middle, right) = fork(
self,
{ left?.mapReduceImpl(map, reduce) },
{ shallowMapReduce(map, reduce) },
{ right?.mapReduceImpl(map, reduce) }
)
val leftAndMiddle = left?.let { reduce(it, middle) } ?: middle
return right?.let { reduce(leftAndMiddle, it) } ?: leftAndMiddle
}
Comment on lines +159 to +177
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just so I'm not crazy: this is the same implementation as the map case right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think it's just the map lambda type that's different.

}

/**
Expand Down
3 changes: 3 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/EmptyTreapList.kt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ internal class EmptyTreapList<E> private constructor() : TreapList<E>, java.io.S
else -> throw IndexOutOfBoundsException("Empty list")
}

override fun <R : Any> mapReduce(map: (E) -> R, reduce: (R, R) -> R): R? = null
override fun <R : Any> parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R? = null

companion object {
private val instance = EmptyTreapList<Nothing>()
@Suppress("UNCHECKED_CAST")
Expand Down
19 changes: 14 additions & 5 deletions collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,19 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap<K
override fun remove(key: K): TreapMap<K, V> = this
override fun remove(key: K, value: V): TreapMap<K, V> = this

override fun updateValues(transform: (K, V) -> V?): TreapMap<K, V> = this
override fun parallelUpdateValues(parallelThresholdLog2: Int, transform: (K, V) -> V?): TreapMap<K, V> = this
override fun <R : Any> updateValues(
transform: (K, V) -> R?
): TreapMap<K, R> = treapMapOf()

override fun <U> updateEntry(key: K, value: U?, merger: (V?, U?) -> V?): TreapMap<K, V> =
override fun <R : Any> parallelUpdateValues(
parallelThresholdLog2: Int,
transform: (K, V) -> R?
): TreapMap<K, R> = treapMapOf()
Comment on lines +22 to +29
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't you return this and just unsafe cast it? That's ultimately what treapMapOf is doing right? I'm not suggesting we do this, just making sure I understand what the code is doing.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is just a nicer-looking way of doing that.


override fun <R : Any> mapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R? = null
override fun <R : Any> parallelMapReduce(map: (K, V) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R? = null

override fun <U> updateEntry(key: K, value: U, merger: (V?, U) -> V?): TreapMap<K, V> =
when (val v = merger(null, value)) {
null -> this
else -> put(key, v)
Expand Down Expand Up @@ -53,7 +62,7 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap<K
@Suppress("Treapability", "UNCHECKED_CAST")
override fun put(key: K, value: V): TreapMap<K, V> = when (key) {
is PrefersHashTreap -> HashTreapMap(key, value)
is Comparable<*> ->
is Comparable<*> ->
SortedTreapMap<Comparable<Comparable<*>>, V>(key as Comparable<Comparable<*>>, value) as TreapMap<K, V>
else -> HashTreapMap(key, value)
}
Expand All @@ -71,4 +80,4 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap<K
@Suppress("UNCHECKED_CAST")
operator fun <@Treapable K, V> invoke(): EmptyTreapMap<K, V> = instance as EmptyTreapMap<K, V>
}
}
}
2 changes: 2 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ internal class EmptyTreapSet<@Treapable E> private constructor() : TreapSet<E>,
override fun retainAll(elements: Collection<E>): TreapSet<E> = this
override fun single(): E = throw NoSuchElementException("Empty set.")
override fun singleOrNull(): E? = null
override fun <R : Any> mapReduce(map: (E) -> R, reduce: (R, R) -> R): R? = null
override fun <R : Any> parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R? = null

@Suppress("Treapability", "UNCHECKED_CAST")
override fun add(element: E): TreapSet<E> = when (element) {
Expand Down
11 changes: 10 additions & 1 deletion collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ internal class HashTreapMap<@Treapable K, V>(
}
}

override fun <U> shallowUpdate(entryKey: K, toUpdate: U, merger: (V?, U?) -> V?): HashTreapMap<K, V>? {
override fun <U> shallowUpdate(entryKey: K, toUpdate: U, merger: (V?, U) -> V?): HashTreapMap<K, V>? {
return when (this.key) {
entryKey -> {
val newValue = merger(this.value, toUpdate)
Expand Down Expand Up @@ -280,6 +280,15 @@ internal class HashTreapMap<@Treapable K, V>(
forEachPair { (k, v) -> h += AbstractMapEntry.hashCode(k, v) }
return h
}

override fun <R : Any> shallowMapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R {
var result: R? = null
forEachPair {
val mapped = map(it.key, it.value)
result = result?.let { result -> reduce(result, mapped) } ?: mapped
}
return result!!
}
Comment on lines +284 to +291
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the type bound on R here just so we have nullability of result? I'm trying to think of a way to relax this, but I think you can only use, e.g., lateinit on non-null types.

Probably not worth it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We depend on this constraint in mapReduceImpl, where it's a bigger deal. It avoids needing to allocate a result holder for each step of the traversal.

}

internal interface KeyValuePairList<K, V> {
Expand Down
9 changes: 9 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,15 @@ internal class HashTreapSet<@Treapable E>(
}.iterator()

override fun shallowGetSingleElement(): E? = element.takeIf { next == null }

override fun <R : Any> shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R {
var result: R? = null
forEachNodeElement {
val mapped = map(it)
result = result?.let { result -> reduce(result, mapped) } ?: mapped
}
return result!!
}
}

internal interface ElementList<E> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ internal class SortedTreapMap<@Treapable K : Comparable<K>, V>(
}
}

override fun <U> shallowUpdate(entryKey: K, toUpdate: U, merger: (V?, U?) -> V?): SortedTreapMap<K, V>? {
override fun <U> shallowUpdate(entryKey: K, toUpdate: U, merger: (V?, U) -> V?): SortedTreapMap<K, V>? {
val newValue = merger(value, toUpdate)
return when {
newValue == null -> null
Expand Down Expand Up @@ -119,4 +119,6 @@ internal class SortedTreapMap<@Treapable K : Comparable<K>, V>(

fun firstEntry(): Map.Entry<K, V>? = left?.firstEntry() ?: this.asEntry()
fun lastEntry(): Map.Entry<K, V>? = right?.lastEntry() ?: this.asEntry()

override fun <R : Any> shallowMapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R = map(key, value)
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,5 @@ internal class SortedTreapSet<@Treapable E : Comparable<E>>(
override fun shallowComputeHashCode(): Int = treapKey.hashCode()
override fun shallowGetSingleElement(): E = treapKey
override fun shallowForEach(action: (element: E) -> Unit): Unit { action(treapKey) }
override fun <R : Any> shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R = map(treapKey)
}
3 changes: 2 additions & 1 deletion collect/src/main/kotlin/com/certora/collect/Treap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,8 @@ internal fun <@Treapable T, S : Treap<T, S>> S?.computeHashCode(): Int = when {
along a single path, under the assumption that the tree is balanced.
*/
internal tailrec fun <@Treapable T, S : Treap<T, S>> S?.isApproximatelySmallerThanLog2(sizeLog2: Int): Boolean = when {
sizeLog2 < 0 -> throw IllegalArgumentException("sizeLog2 must be positive")
this == null -> true
sizeLog2 <= 0 -> false
sizeLog2 == 0 -> false
else -> this.left.isApproximatelySmallerThanLog2(sizeLog2 - 1)
}
3 changes: 3 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/TreapList.kt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ public sealed interface TreapList<E> : PersistentList<E> {
public fun updateElements(transform: (E) -> E?): TreapList<E>
public fun updateElementsIndexed(transform: (Int, E) -> E?): TreapList<E>

public fun <R : Any> mapReduce(map: (E) -> R, reduce: (R, R) -> R): R?
public fun <R : Any> parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int = 5): R?

/**
A [PersistentList.Builder] that produces a [TreapList].
*/
Expand Down
29 changes: 29 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/TreapListNode.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.certora.collect

import com.certora.forkjoin.*
import kotlin.random.Random
import java.lang.Math.addExact

Expand Down Expand Up @@ -364,6 +365,27 @@ internal class TreapListNode<E> private constructor(
right?.forEachNodeIndexed(right.rightIndex(thisIndex), action)
}


override fun <R : Any> mapReduce(map: (E) -> R, reduce: (R, R) -> R): R =
notForking(this) { mapReduceImpl(map, reduce) }

override fun <R : Any> parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R =
maybeForking(this, threshold = { it.isApproximatelySmallerThanLog2(parallelThresholdLog2) }) {
mapReduceImpl(map, reduce)
}

context(ThresholdForker<TreapListNode<E>>)
private fun <R : Any> mapReduceImpl(map: (E) -> R, reduce: (R, R) -> R): R {
val (left, middle, right) = fork(
this,
{ left?.mapReduceImpl(map, reduce) },
{ map(elem) },
{ right?.mapReduceImpl(map, reduce) }
)
val leftAndMiddle = left?.let { reduce(it, middle) } ?: middle
return right?.let { reduce(leftAndMiddle, it) } ?: leftAndMiddle
}

companion object {
private infix fun <E> TreapListNode<E>?.append(that: TreapListNode<E>?): TreapListNode<E>? = when {
this == null -> that
Expand Down Expand Up @@ -430,5 +452,12 @@ internal class TreapListNode<E> private constructor(
// Build the whole list
return buildLowerPri(Int.MAX_VALUE, elems.next(), Random.Default.nextInt()).node
}

internal tailrec fun <E> TreapListNode<E>?.isApproximatelySmallerThanLog2(sizeLog2: Int): Boolean = when {
sizeLog2 < 0 -> throw IllegalArgumentException("sizeLog2 must be positive")
this == null -> true
sizeLog2 == 0 -> false
else -> this.left.isApproximatelySmallerThanLog2(sizeLog2 - 1)
}
}
}
48 changes: 40 additions & 8 deletions collect/src/main/kotlin/com/certora/collect/TreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,56 @@ public sealed interface TreapMap<K, V> : PersistentMap<K, V> {
merger: (K, V?, V?) -> V?
): TreapMap<K, V>

public fun updateValues(
transform: (K, V) -> V?
): TreapMap<K, V>
/**
Produces a new [TreapMap] with updated entries, by applying the supplied [transform]. Removes entries for which
[transform] returns null.

Note that even a seemingly non-mutating transform may result in a different map, if the map contains null values:

```
val map = treapMapOf("a" to null).updateValues { _, v -> v } // yields an empty map
```
*/
public fun <R : Any> updateValues(
transform: (K, V) -> R?
): TreapMap<K, R>
Comment on lines +47 to +49
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this is a change in the public API right? If V was a nullable type, you could use this function to update the values of the map. But now you can't, updatedValues is unusable for nullable types. Maybe that was a bug, as I suspect that we removed entries for which transform returned null, but still, it means part of the API is unusable depending on your type parameter.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'm favor of this change, as it makes explicit that null values are filtered out, where as previously you could write:

 treapMapOf("foo" to 3, "bar" to null).updateValues { _, v -> v }

and have this return a map without "bar" (and no null keys) but this wasn't reflected in the return type of updateValues

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we document this behavior a bit more explicitly though?

Copy link
Collaborator Author

@ericeil ericeil May 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can definitely write this, before and after this change:

var map = treapMapOf("foo" to 3, "bar" to null).updateValues { _, v -> v }

In both cases the result is a map without the entries that had null values. With this change, the result will be typed as TreapMap<String, Int> instead of TreapMap<String, Int?>.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will add some comments.


/**
Produces a new [TreapMap] with updated entries, by applying the supplied [transform]. Removes entries for which
[transform] returns null.

Operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2.

public fun parallelUpdateValues(
See additional nodes on [updateValues].
*/
public fun <R : Any> parallelUpdateValues(
parallelThresholdLog2: Int = 5,
transform: (K, V) -> V?
): TreapMap<K, V>
transform: (K, V) -> R?
): TreapMap<K, R>

public fun <U> updateEntry(
key: K,
value: U?,
merger: (V?, U?) -> V?
value: U,
merger: (V?, U) -> V?
): TreapMap<K, V>

public fun zip(
m: Map<out K, V>
): Sequence<Map.Entry<K, Pair<V?, V?>>>

/**
Applies the [map] function to each entry, then applies [reduce] to the results, in a depth-first traversal of
the underlying tree. Returns null if the map is empty.
*/
public fun <R : Any> mapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R?

/**
Applies the [map] function to each entry, then applies [reduce] to the results, in a depth-first traversal of
the underlying tree. Returns null if the map is empty.

Operations are performed in parallel for maps larger than (approximately) 2^parallelThresholdLog2.
*/
public fun <R : Any> parallelMapReduce(map: (K, V) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int = 5): R?
}

public fun <@Treapable K, V> treapMapOf(): TreapMap<K, V> = EmptyTreapMap<K, V>()
Expand Down
Loading
Loading