diff --git a/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt index 4f25a71..ae34a52 100644 --- a/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt @@ -47,7 +47,7 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT /** Converts the supplied map key to a TreapKey appropriate to this type of AbstractTreapMap (sorted vs. hashed) */ - abstract fun K.toTreapKey(): TreapKey + abstract fun K.toTreapKey(): TreapKey? /** Does this node contain an entry with the given map key? @@ -110,21 +110,21 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT override fun isEmpty(): Boolean = false override fun containsKey(key: K) = - self.find(key.toTreapKey())?.shallowContainsKey(key) ?: false + key.toTreapKey()?.let { self.find(it) }?.shallowContainsKey(key) ?: false override fun containsValue(value: V) = values.contains(value) override fun get(key: K): V? = - self.find(key.toTreapKey())?.shallowGetValue(key) + key.toTreapKey()?.let { self.find(it) }?.shallowGetValue(key) override fun putAll(m: Map): TreapMap = m.entries.fold(this as TreapMap) { t, e -> t.put(e.key, e.value) } - override fun remove(key: K): TreapMap = - self.remove(key.toTreapKey(), key) ?: clear() + override fun remove(key: K): TreapMap = + key.toTreapKey()?.let { self.remove(it, key) ?: clear() } ?: this override fun remove(key: K, value: V): TreapMap = - self.removeEntry(key.toTreapKey(), key, value) ?: clear() + key.toTreapKey()?.let { self.removeEntry(it, key, value) ?: clear() } ?: this override fun clear(): TreapMap = treapMapOf() @@ -275,7 +275,13 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT ``` */ override fun updateEntry(key: K, value: U, merger: (V?, U) -> V?): TreapMap { - return self.updateEntry(key.toTreapKey().precompute(), key, value, merger, ::new) ?: clear() + val treapKey = key.toTreapKey()?.precompute() + return if (treapKey == null) { + // The key is not compatible with this map type, so it's definitely not in the map. + merger(null, value)?.let { put(key, it) } ?: this + } else { + self.updateEntry(treapKey, key, value, merger, ::new) ?: clear() + } } /** diff --git a/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt index 52fe41d..4900c4e 100644 --- a/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt @@ -39,7 +39,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet> /** Converts the supplied set element to a TreapKey appropriate to this type of AbstractTreapSet (sorted vs. hashed) */ - abstract fun E.toTreapKey(): TreapKey + abstract fun E.toTreapKey(): TreapKey? /** Does this node contain the element? @@ -92,7 +92,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet> } override fun contains(element: E): Boolean = - self.find(element.toTreapKey())?.shallowContains(element) ?: false + element.toTreapKey()?.let { self.find(it) }?.shallowContains(element) ?: false override fun containsAll(elements: Collection): Boolean = elements.useAsTreap( { elementsTreap -> self.containsAllKeys(elementsTreap) }, @@ -110,7 +110,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet> ) override fun remove(element: E): TreapSet = - self.remove(element.toTreapKey(), element) ?: clear() + element.toTreapKey()?.let { self.remove(it, element) ?: clear() } ?: this override fun removeAll(elements: Collection): TreapSet = elements.useAsTreap( { elementsTreap -> (self difference elementsTreap) ?: clear() }, @@ -134,7 +134,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet> ) override fun findEqual(element: E): E? = - self.find(element.toTreapKey())?.shallowFindEqual(element) + element.toTreapKey()?.let { self.find(it) }?.shallowFindEqual(element) @Suppress("UNCHECKED_CAST") override fun single(): E = getSingleElement() ?: when { diff --git a/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt index 5975ae1..f022933 100644 --- a/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt @@ -61,10 +61,8 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap = when (key) { - is PrefersHashTreap -> HashTreapMap(key, value) - is Comparable<*> -> - SortedTreapMap>, V>(key as Comparable>, value) as TreapMap - else -> HashTreapMap(key, value) + !is Comparable<*>?, is PrefersHashTreap -> HashTreapMap(key, value) + else -> SortedTreapMap(key, value) } @Suppress("UNCHECKED_CAST") diff --git a/collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt index e869ca8..4b880b9 100644 --- a/collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt @@ -27,12 +27,9 @@ internal class EmptyTreapSet<@Treapable E> private constructor() : TreapSet, override fun mapReduce(map: (E) -> R, reduce: (R, R) -> R): R? = null override fun parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R? = null - @Suppress("Treapability", "UNCHECKED_CAST") override fun add(element: E): TreapSet = when (element) { - is PrefersHashTreap -> HashTreapSet(element) - is Comparable<*> -> - SortedTreapSet>>(element as Comparable>) as TreapSet - else -> HashTreapSet(element) + !is Comparable<*>?, is PrefersHashTreap -> HashTreapSet(element) + else -> SortedTreapSet(element as E) } @Suppress("UNCHECKED_CAST") diff --git a/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt index 9fc5325..abddbed 100644 --- a/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt @@ -21,7 +21,7 @@ internal class HashTreapMap<@Treapable K, V>( override fun hashCode() = computeHashCode() - override fun K.toTreapKey() = TreapKey.Hashed.FromKey(this) + override fun K.toTreapKey() = TreapKey.Hashed.fromKey(this) override fun new(key: K, value: V): HashTreapMap = HashTreapMap(key, value) override fun put(key: K, value: V): TreapMap = self.add(new(key, value)) diff --git a/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt index f886cb6..fd819f0 100644 --- a/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt @@ -17,7 +17,7 @@ internal class HashTreapSet<@Treapable E>( override fun hashCode(): Int = computeHashCode() - override fun E.toTreapKey() = TreapKey.Hashed.FromKey(this) + override fun E.toTreapKey() = TreapKey.Hashed.fromKey(this) override fun new(element: E): HashTreapSet = HashTreapSet(element) override fun add(element: E): TreapSet = self.add(new(element)) diff --git a/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt index 8ce5e19..497cfaa 100644 --- a/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt @@ -6,21 +6,23 @@ import kotlinx.collections.immutable.PersistentMap A TreapMap specific to Comparable keys. Iterates in the order defined by the objects. We store one element per Treap node, with the map key itself as the Treap key, and an additional `value` field */ -internal class SortedTreapMap<@Treapable K : Comparable?, V>( +internal class SortedTreapMap<@Treapable K, V>( val key: K, val value: V, left: SortedTreapMap? = null, right: SortedTreapMap? = null ) : AbstractTreapMap>(left, right), TreapKey.Sorted { + init { check(key is Comparable<*>?) { "SortedTreapMap keys must be Comparable" } } + override fun hashCode() = computeHashCode() - override fun K.toTreapKey() = TreapKey.Sorted.FromKey(this) + override fun K.toTreapKey() = TreapKey.Sorted.fromKey(this) override fun new(key: K, value: V): SortedTreapMap = SortedTreapMap(key, value) override fun put(key: K, value: V): TreapMap = when (key) { - is PrefersHashTreap -> HashTreapMap(key as K, value) + this + !is Comparable<*>?, is PrefersHashTreap -> HashTreapMap(key, value) + this else -> self.add(new(key, value)) } @@ -98,9 +100,9 @@ internal class SortedTreapMap<@Treapable K : Comparable?, V>( } fun floorEntry(key: K): Map.Entry? { - requireNotNull(key) - val cmp = key.compareTo(this.key) + val cmp = TreapKey.Sorted.fromKey(key)?.compareKeyTo(this) return when { + cmp == null -> null cmp < 0 -> left?.floorEntry(key) cmp > 0 -> right?.floorEntry(key) ?: this.asEntry() else -> this.asEntry() @@ -108,9 +110,9 @@ internal class SortedTreapMap<@Treapable K : Comparable?, V>( } fun ceilingEntry(key: K): Map.Entry? { - requireNotNull(key) - val cmp = key.compareTo(this.key) + val cmp = TreapKey.Sorted.fromKey(key)?.compareKeyTo(this) return when { + cmp == null -> null cmp < 0 -> left?.ceilingEntry(key) ?: this.asEntry() cmp > 0 -> right?.ceilingEntry(key) else -> this.asEntry() @@ -118,18 +120,18 @@ internal class SortedTreapMap<@Treapable K : Comparable?, V>( } fun lowerEntry(key: K): Map.Entry? { - requireNotNull(key) - val cmp = key.compareTo(this.key) + val cmp = TreapKey.Sorted.fromKey(key)?.compareKeyTo(this) return when { + cmp == null -> null cmp > 0 -> right?.lowerEntry(key) ?: this.asEntry() else -> left?.lowerEntry(key) } } fun higherEntry(key: K): Map.Entry? { - requireNotNull(key) - val cmp = key.compareTo(this.key) + val cmp = TreapKey.Sorted.fromKey(key)?.compareKeyTo(this) return when { + cmp == null -> null cmp < 0 -> left?.higherEntry(key) ?: this.asEntry() else -> right?.higherEntry(key) } diff --git a/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt index 5eced93..d683b71 100644 --- a/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt @@ -6,19 +6,21 @@ import kotlinx.collections.immutable.PersistentSet A TreapSet specific to Comparable elements. Iterates in the order defined by the objects. We store one element per Treap node, with the element itself as the Treap key. */ -internal class SortedTreapSet<@Treapable E : Comparable?>( +internal class SortedTreapSet<@Treapable E>( override val treapKey: E, left: SortedTreapSet? = null, right: SortedTreapSet? = null ) : AbstractTreapSet>(left, right), TreapKey.Sorted { + init { check(treapKey is Comparable<*>?) { "SortedTreapSet elements must be Comparable" } } + override fun hashCode(): Int = computeHashCode() - override fun E.toTreapKey() = TreapKey.Sorted.FromKey(this) + override fun E.toTreapKey() = TreapKey.Sorted.fromKey(this) override fun new(element: E): SortedTreapSet = SortedTreapSet(element) override fun add(element: E): TreapSet = when(element) { - is PrefersHashTreap -> HashTreapSet(element as E) + this + !is Comparable<*>?, is PrefersHashTreap -> HashTreapSet(element) + this else -> self.add(new(element)) } diff --git a/collect/src/main/kotlin/com/certora/collect/TreapKey.kt b/collect/src/main/kotlin/com/certora/collect/TreapKey.kt index d6c02dc..f5b454d 100644 --- a/collect/src/main/kotlin/com/certora/collect/TreapKey.kt +++ b/collect/src/main/kotlin/com/certora/collect/TreapKey.kt @@ -64,11 +64,9 @@ internal interface TreapKey<@Treapable K> { /** A TreapKey whose underlying key implements Comparable. This allows us to sort the Treap naturally. */ - interface Sorted<@Treapable K : Comparable?> : TreapKey { + interface Sorted<@Treapable K> : TreapKey { abstract override val treapKey: K - // Note that we must never compare a Hashed key with a Sorted key. We'd check that here, but this is extremely - // perf-critical code. override fun compareKeyTo(that: TreapKey): Int { val thisTreapKey = this.treapKey val thatTreapKey = that.treapKey @@ -76,14 +74,23 @@ internal interface TreapKey<@Treapable K> { thisTreapKey === thatTreapKey -> 0 thisTreapKey == null -> -1 thatTreapKey == null -> 1 - else -> thisTreapKey.compareTo(thatTreapKey) + else -> { + @Suppress("UNCHECKED_CAST") + (thisTreapKey as Comparable).compareTo(thatTreapKey) + } } } - override fun precompute() = FromKey(treapKey) + override fun precompute() = fromKey(treapKey)!! - class FromKey<@Treapable K : Comparable?>(override val treapKey: K) : Sorted { - override val treapPriority = super.treapPriority // precompute the priority + companion object { + fun <@Treapable K> fromKey(key: K): Sorted? = when (key) { + is Comparable<*>? -> object : Sorted { + override val treapKey = key + override val treapPriority = super.treapPriority // precompute the priority + } + else -> null + } } } @@ -107,11 +114,14 @@ internal interface TreapKey<@Treapable K> { } } - override fun precompute() = FromKey(treapKey) + override fun precompute() = fromKey(treapKey) - class FromKey<@Treapable K>(override val treapKey: K) : Hashed { - override val treapKeyHashCode = treapKey.hashCode() // precompute the hash code - override val treapPriority = super.treapPriority // precompute the priority + companion object { + fun <@Treapable K> fromKey(key: K): Hashed = object : Hashed { + override val treapKey = key + override val treapKeyHashCode = treapKey.hashCode() // precompute the hash code + override val treapPriority = super.treapPriority // precompute the priority + } } } }