Skip to content

Commit

Permalink
More nullable stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
ericeil committed May 16, 2024
1 parent 05a39a1 commit a222ec9
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 57 deletions.
97 changes: 44 additions & 53 deletions collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import kotlinx.collections.immutable.ImmutableSet
internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractTreapMap<K, V, S>>(
left: S?,
right: S?
) : TreapMap<K, V>, Treap<K, S>(left, right) {
) : TreapMap<K, V>, Treap<K, S>(left, right), TreapKey<K> {

/**
Derived classes override to create an apropriate node containing the given entry.
Expand All @@ -25,23 +25,6 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
*/
abstract fun Map<out K, V>.toTreapMapOrNull(): AbstractTreapMap<K, V, S>?

/**
Converts the given Map to a AbstractTreapMap of the same type as 'this'. May copy the map.
*/
fun Map<out K, V>.toTreapMapIfNotEmpty(): AbstractTreapMap<K, V, S>? =
toTreapMapOrNull() ?: when {
isEmpty() -> null
else -> {
val i = entries.iterator()
var m: AbstractTreapMap<K, V, S> = i.next().let { (k, v) -> new(k, v) }
while (i.hasNext()) {
val (k, v) = i.next()
m = m.put(k, v)
}
m
}
}

/**
Given a map, calls the supplied `action` if the collection is a Treap of the same type as this Treap, otherwise
calls `fallback.` Used to implement optimized operations over two compatible Treaps, with a fallback when
Expand Down Expand Up @@ -134,9 +117,6 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
override fun get(key: K): V? =
self.find(key.toTreapKey())?.shallowGetValue(key)

override fun put(key: K, value: V): AbstractTreapMap<K, V, S> =
self.add(new(key, value))

override fun putAll(m: Map<out K, V>): TreapMap<K, V> =
m.entries.fold(this as TreapMap<K, V>) { t, e -> t.put(e.key, e.value) }

Expand Down Expand Up @@ -305,47 +285,58 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
override fun zip(m: Map<out K, V>) = sequence<Map.Entry<K, Pair<V?, V?>>> {
fun <T> Iterator<T>.nextOrNull() = if (hasNext()) { next() } else { null }

// Iterate over the two maps' treap sequences. We ensure that each sequence uses the same key ordering, by
// converting `m` to a TreapMap of this map's type, if necessary. Note that we can't use entrySequence, because
// HashTreapMap's entrySequence is only partially ordered.
val thisIt = asTreapSequence().iterator()
val that: Treap<K, S>? = m.toTreapMapIfNotEmpty()
val thatIt = that?.asTreapSequence()?.iterator()

var thisCurrent = thisIt.nextOrNull()
var thatCurrent = thatIt?.nextOrNull()

while (thisCurrent != null && thatIt != null && thatCurrent != null) {
val c = thisCurrent.compareKeyTo(thatCurrent)
when {
c < 0 -> {
yieldAll(thisCurrent.shallowZipThisOnly())
thisCurrent = thisIt.nextOrNull()
}
c > 0 -> {
yieldAll(thatCurrent.shallowZipThatOnly())
thatCurrent = thatIt.nextOrNull()
val sequences = getTreapSequencesIfSameType(m)
if (sequences != null) {
// Fast case for when the maps are the same type
val thisIt = sequences.first.iterator()
val thatIt = sequences.second.iterator()

var thisCurrent = thisIt.nextOrNull()
var thatCurrent = thatIt.nextOrNull()

while (thisCurrent != null && thatCurrent != null) {
val c = thisCurrent.compareKeyTo(thatCurrent)
when {
c < 0 -> {
yieldAll(thisCurrent.shallowZipThisOnly())
thisCurrent = thisIt.nextOrNull()
}
c > 0 -> {
yieldAll(thatCurrent.shallowZipThatOnly())
thatCurrent = thatIt.nextOrNull()
}
else -> {
yieldAll(thisCurrent.shallowZip(thatCurrent))
thisCurrent = thisIt.nextOrNull()
thatCurrent = thatIt.nextOrNull()
}
}
else -> {
yieldAll(thisCurrent.shallowZip(thatCurrent))
thisCurrent = thisIt.nextOrNull()
thatCurrent = thatIt.nextOrNull()
}
while (thisCurrent != null) {
yieldAll(thisCurrent.shallowZipThisOnly())
thisCurrent = thisIt.nextOrNull()
}
while (thatCurrent != null) {
yieldAll(thatCurrent.shallowZipThatOnly())
thatCurrent = thatIt.nextOrNull()
}
} else {
// Slower fallback for maps of different types
for ((k, v) in entries) {
yield(MapEntry(k, v to m[k]))
}
for ((k, v) in m.entries) {
if (k !in this@AbstractTreapMap) {
yield(MapEntry(k, null to v))
}
}
}
while (thisCurrent != null) {
yieldAll(thisCurrent.shallowZipThisOnly())
thisCurrent = thisIt.nextOrNull()
}
while (thatIt != null && thatCurrent != null) {
yieldAll(thatCurrent.shallowZipThatOnly())
thatCurrent = thatIt.nextOrNull()
}
}

private fun shallowZipThisOnly() = shallowEntrySequence().map { MapEntry(it.key, it.value to null) }
private fun shallowZipThatOnly() = shallowEntrySequence().map { MapEntry(it.key, null to it.value) }
protected abstract fun shallowZip(that: S): Sequence<Map.Entry<K, Pair<V?, V?>>>
protected abstract fun getTreapSequencesIfSameType(that: Map<out K, V>): Pair<Sequence<S>, Sequence<S>>?

override fun <R : Any> mapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R =
notForking(self) { mapReduceImpl(map, reduce) }
Expand Down
11 changes: 11 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ internal class HashTreapMap<@Treapable K, V>(
override fun K.toTreapKey() = TreapKey.Hashed.FromKey(this)
override fun new(key: K, value: V): HashTreapMap<K, V> = HashTreapMap(key, value)

override fun put(key: K, value: V): TreapMap<K, V> = self.add(new(key, value))

@Suppress("UNCHECKED_CAST")
override fun Map<out K, V>.toTreapMapOrNull() =
this as? HashTreapMap<K, V>
Expand Down Expand Up @@ -88,6 +90,15 @@ internal class HashTreapMap<@Treapable K, V>(
return false
}

protected override fun getTreapSequencesIfSameType(
that: Map<out K, V>
): Pair<Sequence<HashTreapMap<K, V>>, Sequence<HashTreapMap<K, V>>>? {
@Suppress("UNCHECKED_CAST")
return (that as? HashTreapMap<K, V>)?.let {
this.asTreapSequence() to it.asTreapSequence()
}
}

override fun shallowZip(that: HashTreapMap<K, V>): Sequence<Map.Entry<K, Pair<V?, V?>>> = sequence {
forEachPair {
yield(MapEntry(it.key, it.value to that.shallowGetValue(it.key)))
Expand Down
22 changes: 20 additions & 2 deletions collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import kotlinx.collections.immutable.PersistentMap
A TreapMap specific to Comparable keys. Iterates in the order defined by the objects. We store one element per
Treap node, with the map key itself as the Treap key, and an additional `value` field
*/
internal class SortedTreapMap<@Treapable K : Comparable<K>, V>(
internal class SortedTreapMap<@Treapable K : Comparable<K>?, V>(
val key: K,
val value: V,
left: SortedTreapMap<K, V>? = null,
Expand All @@ -19,13 +19,18 @@ internal class SortedTreapMap<@Treapable K : Comparable<K>, V>(

override fun new(key: K, value: V): SortedTreapMap<K, V> = SortedTreapMap(key, value)

override fun put(key: K, value: V): TreapMap<K, V> = when (key) {
is PrefersHashTreap -> HashTreapMap(key as K, value) + this
else -> self.add(new(key, value))
}

@Suppress("UNCHECKED_CAST")
override fun Map<out K, V>.toTreapMapOrNull() =
this as? SortedTreapMap<K, V>
?: (this as? PersistentMap.Builder<K, V>)?.build() as? SortedTreapMap<K, V>

override fun getShallowMerger(merger: (K, V?, V?) -> V?): (SortedTreapMap<K, V>?, SortedTreapMap<K, V>?) -> SortedTreapMap<K, V>? = { t1, t2 ->
val k = t1?.key ?: t2?.key as K
val k = (t1 ?: t2)?.key!!
val v1 = t1?.value
val v2 = t2?.value
val v = merger(k, v1, v2)
Expand All @@ -37,6 +42,15 @@ internal class SortedTreapMap<@Treapable K : Comparable<K>, V>(
}
}

protected override fun getTreapSequencesIfSameType(
that: Map<out K, V>
): Pair<Sequence<SortedTreapMap<K, V>>, Sequence<SortedTreapMap<K, V>>>? {
@Suppress("UNCHECKED_CAST")
return (that as? SortedTreapMap<K, V>)?.let {
this.asTreapSequence() to it.asTreapSequence()
}
}

override fun shallowZip(that: SortedTreapMap<K, V>): Sequence<Map.Entry<K, Pair<V, V>>> =
sequenceOf(MapEntry(this.key, this.value to that.value))

Expand Down Expand Up @@ -84,6 +98,7 @@ internal class SortedTreapMap<@Treapable K : Comparable<K>, V>(
}

fun floorEntry(key: K): Map.Entry<K, V>? {
requireNotNull(key)
val cmp = key.compareTo(this.key)
return when {
cmp < 0 -> left?.floorEntry(key)
Expand All @@ -93,6 +108,7 @@ internal class SortedTreapMap<@Treapable K : Comparable<K>, V>(
}

fun ceilingEntry(key: K): Map.Entry<K, V>? {
requireNotNull(key)
val cmp = key.compareTo(this.key)
return when {
cmp < 0 -> left?.ceilingEntry(key) ?: this.asEntry()
Expand All @@ -102,6 +118,7 @@ internal class SortedTreapMap<@Treapable K : Comparable<K>, V>(
}

fun lowerEntry(key: K): Map.Entry<K, V>? {
requireNotNull(key)
val cmp = key.compareTo(this.key)
return when {
cmp > 0 -> right?.lowerEntry(key) ?: this.asEntry()
Expand All @@ -110,6 +127,7 @@ internal class SortedTreapMap<@Treapable K : Comparable<K>, V>(
}

fun higherEntry(key: K): Map.Entry<K, V>? {
requireNotNull(key)
val cmp = key.compareTo(this.key)
return when {
cmp < 0 -> left?.higherEntry(key) ?: this.asEntry()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ internal class SortedTreapSet<@Treapable E : Comparable<E>?>(
override fun new(element: E): SortedTreapSet<E> = SortedTreapSet(element)

override fun add(element: E): TreapSet<E> = when(element) {
null, is PrefersHashTreap -> HashTreapSet(element) + this
is PrefersHashTreap -> HashTreapSet(element as E) + this
else -> self.add(new(element))
}

Expand Down
11 changes: 10 additions & 1 deletion collect/src/main/kotlin/com/certora/collect/TreapKey.kt
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,16 @@ internal interface TreapKey<@Treapable K> {

// Note that we must never compare a Hashed key with a Sorted key. We'd check that here, but this is extremely
// perf-critical code.
override fun compareKeyTo(that: TreapKey<K>) = this.treapKey!!.compareTo(that.treapKey)
override fun compareKeyTo(that: TreapKey<K>): Int {
val thisTreapKey = this.treapKey
val thatTreapKey = that.treapKey
return when {
thisTreapKey === thatTreapKey -> 0
thisTreapKey == null -> -1
thatTreapKey == null -> 1
else -> thisTreapKey.compareTo(thatTreapKey)
}
}

override fun precompute() = FromKey(treapKey)

Expand Down
32 changes: 32 additions & 0 deletions collect/src/test/kotlin/com/certora/collect/TreapMapTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,38 @@ abstract class TreapMapTest {
}
}

@Test
fun addNullMapAtEnd() {
val s = makeMap()
s.putAll(treapMapOf(makeKey(0) to null))
s.putAll(treapMapOf(null to null))
assertEquals(s, mapOf<TestKey?, Any?>(null to null, makeKey(0) to null))
}

@Test
fun addNullMapAtStart() {
val s = makeMap()
s.putAll(treapMapOf(null to null))
s.putAll(treapMapOf(makeKey(0) to null))
assertEquals(s, mapOf<TestKey?, Any?>(null to null, makeKey(0) to null))
}

@Test
fun addNullKeyAtEnd() {
val s = makeMap()
s.put(makeKey(0), null)
s.put(null, null)
assertEquals(s, mapOf<TestKey?, Any?>(null to null, makeKey(0) to null))
}

@Test
fun addNullKeyAtStart() {
val s = makeMap()
s.put(null, null)
s.put(makeKey(0), null)
assertEquals(s, mapOf<TestKey?, Any?>(null to null, makeKey(0) to null))
}

@Test
fun copyConstructorEmpty() {
val empty = mapOf<TestKey?, Any?>()
Expand Down

0 comments on commit a222ec9

Please sign in to comment.