Skip to content

Commit

Permalink
Add new methods for querying sorted treap maps (#13)
Browse files Browse the repository at this point in the history
Adds the following methods to `TreapMap`, with inspiration from the
JDK's `TreeMap`:

`firstKey`
`firstEntry`
`floorKey`
`floorEntry`
`lowerKey`
`lowerEntry`

`lastKey`
`lastEntry`
`ceilingKey`
`ceilingEntry`
`higherKey`
`higherEntry`

These are implemented as extension methods so that they can be
constrained to only key types that implement `Comparable`. To allow this
to be done safely, I've also changed `TreapMap`, _et al_, to be
`sealed`, as they should have been all along.
  • Loading branch information
ericeil authored Feb 17, 2024
1 parent aac4001 commit 04fb5d3
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 28 deletions.
24 changes: 12 additions & 12 deletions collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import kotlinx.collections.immutable.ImmutableSet
Base class for TreapMap implementations. Provides the Map operations; derived classes deal with type-specific
behavior such as hash collisions. See [Treap] for an overview of all of this.
*/
internal abstract class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>>(
internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>>(
left: S?,
right: S?
) : TreapMap<K, V>, Treap<K, S>(left, right) {
Expand All @@ -28,7 +28,7 @@ internal abstract class AbstractTreapMap<@Treapable K, V, @Treapable S : Abstrac
/**
Converts the given Map to a AbstractTreapMap of the same type as 'this'. May copy the map.
*/
fun Map<out K, V>.toTreapMapIfNotEmpty(): AbstractTreapMap<K, V, S>? =
fun Map<out K, V>.toTreapMapIfNotEmpty(): AbstractTreapMap<K, V, S>? =
toTreapMapOrNull() ?: when {
isEmpty() -> null
else -> {
Expand Down Expand Up @@ -338,8 +338,8 @@ internal abstract class AbstractTreapMap<@Treapable K, V, @Treapable S : Abstrac
Removes a map entry (`entryKey`, `entryValue`) with key `key`.
*/
internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.removeEntry(
key: TreapKey<K>,
entryKey: K,
key: TreapKey<K>,
entryKey: K,
entryValue: V
): S? = when {
this == null -> null
Expand All @@ -355,10 +355,10 @@ internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.remo
}

internal fun <@Treapable K, V, U, @Treapable S : AbstractTreapMap<K, V, S>> S?.updateEntry(
thatKey: TreapKey<K>,
entryKey: K,
toUpdate: U?,
merger: (V?, U?) -> V?,
thatKey: TreapKey<K>,
entryKey: K,
toUpdate: U?,
merger: (V?, U?) -> V?,
new: (K, V) -> S
): S? = when {
this == null -> {
Expand Down Expand Up @@ -397,16 +397,16 @@ internal fun <@Treapable K, V, U, @Treapable S : AbstractTreapMap<K, V, S>> S?.u
'this' over 'that', to preserve the object identity invariant described in the `Treap` summary.
*/
internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.mergeWith(
that: S?,
that: S?,
shallowMerge: (S?, S?) -> S?
): S? =
notForking(this to that) {
mergeWithImpl(that, shallowMerge)
}

internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.parallelMergeWith(
that: S?,
parallelThresholdLog2: Int,
that: S?,
parallelThresholdLog2: Int,
shallowMerge: (S?, S?) -> S?
): S? =
maybeForking(
Expand All @@ -421,7 +421,7 @@ internal fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.para

context(ThresholdForker<Pair<S?, S?>>)
private fun <@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>> S?.mergeWithImpl(
that: S?,
that: S?,
shallowMerge: (S?, S?) -> S?
): S? {
val (newLeft, newRight, newThis) = when {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ package com.certora.collect
Base class for TreapSet implementations. Provides the Set operations; derived classes deal with type-specific
behavior such as hash collisions. See `Treap` for an overview of all of this.
*/
internal abstract class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>(
left: S?,
internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>(
left: S?,
right: S?
) : TreapSet<E>, Treap<E, S>(left, right) {
/**
Expand Down Expand Up @@ -144,7 +144,7 @@ internal abstract class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S

override fun forEachElement(action: (element: E) -> Unit): Unit {
left?.forEachElement(action)
shallowForEach(action)
shallowForEach(action)
right?.forEachElement(action)
}

Expand Down
39 changes: 39 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ internal class SortedTreapMap<@Treapable K : Comparable<K>, V>(
override val self get() = this
override val treapKey get() = key

fun asEntry(): Map.Entry<K, V> = MapEntry(key, value)

override fun shallowEntrySequence(): Sequence<Map.Entry<K, V>> = sequenceOf(MapEntry(key, value))

override fun shallowContainsKey(key: K) = true
Expand Down Expand Up @@ -80,4 +82,41 @@ internal class SortedTreapMap<@Treapable K : Comparable<K>, V>(
else -> SortedTreapMap(key, newValue, left, right)
}
}

fun floorEntry(key: K): Map.Entry<K, V>? {
val cmp = key.compareTo(this.key)
return when {
cmp < 0 -> left?.floorEntry(key)
cmp > 0 -> right?.floorEntry(key) ?: this.asEntry()
else -> this.asEntry()
}
}

fun ceilingEntry(key: K): Map.Entry<K, V>? {
val cmp = key.compareTo(this.key)
return when {
cmp < 0 -> left?.ceilingEntry(key) ?: this.asEntry()
cmp > 0 -> right?.ceilingEntry(key)
else -> this.asEntry()
}
}

fun lowerEntry(key: K): Map.Entry<K, V>? {
val cmp = key.compareTo(this.key)
return when {
cmp > 0 -> right?.lowerEntry(key) ?: this.asEntry()
else -> left?.lowerEntry(key)
}
}

fun higherEntry(key: K): Map.Entry<K, V>? {
val cmp = key.compareTo(this.key)
return when {
cmp < 0 -> left?.higherEntry(key) ?: this.asEntry()
else -> right?.higherEntry(key)
}
}

fun firstEntry(): Map.Entry<K, V>? = left?.firstEntry() ?: this.asEntry()
fun lastEntry(): Map.Entry<K, V>? = right?.lastEntry() ?: this.asEntry()
}
2 changes: 1 addition & 1 deletion collect/src/main/kotlin/com/certora/collect/TreapList.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import kotlinx.collections.immutable.PersistentList
A PersistentList implemented as a [Treap](https://en.wikipedia.org/wiki/Treap).
*/
@Treapable
public interface TreapList<E> : PersistentList<E> {
public sealed interface TreapList<E> : PersistentList<E> {
override fun add(element: E): TreapList<E> = addLast(element)
override fun addAll(elements: Collection<E>): TreapList<E>
override fun remove(element: E): TreapList<E>
Expand Down
118 changes: 107 additions & 11 deletions collect/src/main/kotlin/com/certora/collect/TreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@ package com.certora.collect
import kotlinx.collections.immutable.PersistentMap

/**
A [PersistentMap] implemented as a [Treap](https://en.wikipedia.org/wiki/Treap) - a kind of balanced binary tree.
A [PersistentMap] implemented as a [Treap](https://en.wikipedia.org/wiki/Treap) - a kind of balanced binary tree.
*/
@Treapable
public interface TreapMap<K, V> : PersistentMap<K, V> {
public sealed interface TreapMap<K, V> : PersistentMap<K, V> {
override fun put(key: K, value: @UnsafeVariance V): TreapMap<K, V>
override fun remove(key: K): TreapMap<K, V>
override fun remove(key: K, value: @UnsafeVariance V): TreapMap<K, V>
override fun putAll(m: Map<out K, @UnsafeVariance V>): TreapMap<K, V>
override fun clear(): TreapMap<K, V>

/**
A [PersistentMap.Builder] that produces a [TreapMap].
A [PersistentMap.Builder] that produces a [TreapMap].
*/
public interface Builder<K, V>: PersistentMap.Builder<K, V> {
override fun build(): TreapMap<K, V>
Expand All @@ -24,13 +24,13 @@ public interface TreapMap<K, V> : PersistentMap<K, V> {
override fun builder(): Builder<K, @UnsafeVariance V> = TreapMapBuilder(this)

public fun merge(
m: Map<K, V>,
m: Map<K, V>,
merger: (K, V?, V?) -> V?
): TreapMap<K, V>

public fun parallelMerge(
m: Map<K, V>,
parallelThresholdLog2: Int = 4,
m: Map<K, V>,
parallelThresholdLog2: Int = 4,
merger: (K, V?, V?) -> V?
): TreapMap<K, V>

Expand All @@ -39,13 +39,13 @@ public interface TreapMap<K, V> : PersistentMap<K, V> {
): TreapMap<K, V>

public fun parallelUpdateValues(
parallelThresholdLog2: Int = 5,
parallelThresholdLog2: Int = 5,
transform: (K, V) -> V?
): TreapMap<K, V>
): TreapMap<K, V>

public fun <U> updateEntry(
key: K,
value: U?,
key: K,
value: U?,
merger: (V?, U?) -> V?
): TreapMap<K, V>

Expand Down Expand Up @@ -133,3 +133,99 @@ public inline fun <K, V, @Treapable R> TreapMap<out K, V>.mapKeys(transform: (Ma

public inline fun <@Treapable K, V, R> TreapMap<out K, V>.mapValues(transform: (Map.Entry<K, V>) -> R): TreapMap<K, R> =
mapValuesTo(treapMapOf<K, R>().builder(), transform).build()

/**
Returns a key-value mapping associated with the greatest key less than or equal to the given key, or null if there
is no such key.
*/
public fun <@Treapable K : Comparable<K>, V> TreapMap<K, V>.floorEntry(key: K): Map.Entry<K, V>? = when (this) {
is EmptyTreapMap<K, V> -> null
is SortedTreapMap<K, V> -> floorEntry(key)
// Shouldn't happen due to static Comparable constraint on K
is HashTreapMap<K, V> -> throw UnsupportedOperationException("floorEntry is not supported for hashed treap maps")
}

/**
Returns the greatest key less than or equal to the given key, or null if there is no such key.
*/
public fun <@Treapable K : Comparable<K>, V> TreapMap<K, V>.floorKey(key: K): K? = floorEntry(key)?.key

/**
Returns a key-value mapping associated with the least key greater than or equal to the given key, or null if there
is no such key.
*/
public fun <@Treapable K : Comparable<K>, V> TreapMap<K, V>.ceilingEntry(key: K): Map.Entry<K, V>? = when (this) {
is EmptyTreapMap<K, V> -> null
is SortedTreapMap<K, V> -> ceilingEntry(key)
// Shouldn't happen due to static Comparable constraint on K
is HashTreapMap<K, V> -> throw UnsupportedOperationException("ceilingEntry is not supported for hashed treap maps")
}

/**
Returns the least key greater than or equal to the given key, or null if there is no such key.
*/
public fun <@Treapable K : Comparable<K>, V> TreapMap<K, V>.ceilingKey(key: K): K? = ceilingEntry(key)?.key

/**
Returns a key-value mapping associated with the greatest key strictly less than the given key, or null if there is
no such key.
*/
public fun <@Treapable K : Comparable<K>, V> TreapMap<K, V>.lowerEntry(key: K): Map.Entry<K, V>? = when (this) {
is EmptyTreapMap<K, V> -> null
is SortedTreapMap<K, V> -> lowerEntry(key)
// Shouldn't happen due to static Comparable constraint on K
is HashTreapMap<K, V> -> throw UnsupportedOperationException("lowerEntry is not supported for hashed treap maps")
}

/**
Returns the greatest key strictly less than the given key, or null if there is no such key.
*/
public fun <@Treapable K : Comparable<K>, V> TreapMap<K, V>.lowerKey(key: K): K? = lowerEntry(key)?.key


/**
Returns a key-value mapping associated with the least key strictly greater than the given key, or null if there is no
such key.
*/
public fun <@Treapable K : Comparable<K>, V> TreapMap<K, V>.higherEntry(key: K): Map.Entry<K, V>? = when (this) {
is EmptyTreapMap<K, V> -> null
is SortedTreapMap<K, V> -> higherEntry(key)
// Shouldn't happen due to static Comparable constraint on K
is HashTreapMap<K, V> -> throw UnsupportedOperationException("higherEntry is not supported for hashed treap maps")
}

/**
Returns the least key strictly greater than the given key, or null if there is no such key.
*/
public fun <@Treapable K : Comparable<K>, V> TreapMap<K, V>.higherKey(key: K): K? = higherEntry(key)?.key

/**
Returns a key-value mapping associated with the least key in this map, or null if the map is empty.
*/
public fun <@Treapable K : Comparable<K>, V> TreapMap<K, V>.firstEntry(): Map.Entry<K, V>? = when (this) {
is EmptyTreapMap<K, V> -> null
is SortedTreapMap<K, V> -> firstEntry()
// Shouldn't happen due to static Comparable constraint on K
is HashTreapMap<K, V> -> throw UnsupportedOperationException("firstEntry is not supported for hashed treap maps")
}

/**
Returns the least key in this map, or null if the map is empty.
*/
public fun <@Treapable K : Comparable<K>, V> TreapMap<K, V>.firstKey(): K? = firstEntry()?.key


/**
Returns a key-value mapping associated with the greatest key in this map, or null if the map is empty.
*/
public fun <@Treapable K : Comparable<K>, V> TreapMap<K, V>.lastEntry(): Map.Entry<K, V>? = when (this) {
is EmptyTreapMap<K, V> -> null
is SortedTreapMap<K, V> -> lastEntry()
// Shouldn't happen due to static Comparable constraint on K
is HashTreapMap<K, V> -> throw UnsupportedOperationException("lastEntry is not supported for hashed treap maps")
}

/**
Returns the greatest key in this map, or null if the map is empty.
*/
public fun <@Treapable K : Comparable<K>, V> TreapMap<K, V>.lastKey(): K? = lastEntry()?.key
2 changes: 1 addition & 1 deletion collect/src/main/kotlin/com/certora/collect/TreapSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import kotlinx.collections.immutable.PersistentSet
A [PersistentSet] implemented as a [Treap](https://en.wikipedia.org/wiki/Treap) - a kind of balanced binary tree.
*/
@Treapable
public interface TreapSet<out T> : PersistentSet<T> {
public sealed interface TreapSet<out T> : PersistentSet<T> {
override fun add(element: @UnsafeVariance T): TreapSet<T>
override fun addAll(elements: Collection<@UnsafeVariance T>): TreapSet<T>
override fun remove(element: @UnsafeVariance T): TreapSet<T>
Expand Down
56 changes: 56 additions & 0 deletions collect/src/test/kotlin/com/certora/collect/SortedTreapMapTest.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.certora.collect

import com.certora.collect.*
import kotlin.test.*
import kotlinx.serialization.DeserializationStrategy
import java.util.TreeMap

Expand All @@ -18,4 +19,59 @@ class SortedTreapMapTest: TreapMapTest() {
override fun makeBaselineOfInts(): MutableMap<Int?, Int?> = TreeMap()
override fun getBaseDeserializer(): DeserializationStrategy<*>? = null
override fun getDeserializer(): DeserializationStrategy<*>? = null

@Test
fun ceilingFloorHigherLowerLastFirst() {

val empty = treapMapOf<Int, Int>()
assertNull(empty.firstKey())
assertNull(empty.lastKey())
assertNull(empty.floorKey(0))
assertNull(empty.lowerKey(0))
assertNull(empty.ceilingKey(0))
assertNull(empty.higherKey(0))

val map = treapMapOf<Int, Int>().mutate {
for (i in 0..1000 step 2) {
it[i] = i + 1
}
}

assertEquals(0, map.firstKey())
assertEquals(1000, map.lastKey())

assertNull(map.floorKey(-1))
assertNull(map.lowerKey(-1))
assertEquals(0, map.ceilingKey(-1))
assertEquals(0, map.higherKey(-1))

assertEquals(0, map.floorKey(0))
assertNull(map.lowerKey(0))
assertEquals(0, map.ceilingKey(0))
assertEquals(2, map.higherKey(0))

for (i in 2..998 step 2) {
assertEquals(i, map.floorKey(i))
assertEquals(i - 2, map.lowerKey(i))
assertEquals(i, map.ceilingKey(i))
assertEquals(i + 2, map.higherKey(i))
}

for (i in 1..999 step 2) {
assertEquals(i - 1, map.floorKey(i))
assertEquals(i - 1, map.lowerKey(i))
assertEquals(i + 1, map.ceilingKey(i))
assertEquals(i + 1, map.higherKey(i))
}

assertEquals(1000, map.floorKey(1000))
assertEquals(998, map.lowerKey(1000))
assertEquals(1000, map.ceilingKey(1000))
assertNull(map.higherKey(1000))

assertEquals(1000, map.floorKey(1001))
assertEquals(1000, map.lowerKey(1001))
assertNull(map.ceilingKey(1001))
assertNull(map.higherKey(1001))
}
}

0 comments on commit 04fb5d3

Please sign in to comment.