Skip to content

Commit

Permalink
Fix more type issues
Browse files Browse the repository at this point in the history
  • Loading branch information
ericeil committed May 16, 2024
1 parent 57a5fdc commit e0007af
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 47 deletions.
20 changes: 13 additions & 7 deletions collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
/**
Converts the supplied map key to a TreapKey appropriate to this type of AbstractTreapMap (sorted vs. hashed)
*/
abstract fun K.toTreapKey(): TreapKey<K>
abstract fun K.toTreapKey(): TreapKey<K>?

/**
Does this node contain an entry with the given map key?
Expand Down Expand Up @@ -110,21 +110,21 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
override fun isEmpty(): Boolean = false

override fun containsKey(key: K) =
self.find(key.toTreapKey())?.shallowContainsKey(key) ?: false
key.toTreapKey()?.let { self.find(it) }?.shallowContainsKey(key) ?: false

override fun containsValue(value: V) = values.contains(value)

override fun get(key: K): V? =
self.find(key.toTreapKey())?.shallowGetValue(key)
key.toTreapKey()?.let { self.find(it) }?.shallowGetValue(key)

override fun putAll(m: Map<out K, V>): TreapMap<K, V> =
m.entries.fold(this as TreapMap<K, V>) { t, e -> t.put(e.key, e.value) }

override fun remove(key: K): TreapMap<K, V> =
self.remove(key.toTreapKey(), key) ?: clear()
override fun remove(key: K): TreapMap<K, V> =
key.toTreapKey()?.let { self.remove(it, key) ?: clear() } ?: this

override fun remove(key: K, value: V): TreapMap<K, V> =
self.removeEntry(key.toTreapKey(), key, value) ?: clear()
key.toTreapKey()?.let { self.removeEntry(it, key, value) ?: clear() } ?: this

override fun clear(): TreapMap<K, V> = treapMapOf<K, V>()

Expand Down Expand Up @@ -275,7 +275,13 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
```
*/
override fun <U> updateEntry(key: K, value: U, merger: (V?, U) -> V?): TreapMap<K, V> {
return self.updateEntry(key.toTreapKey().precompute(), key, value, merger, ::new) ?: clear()
val treapKey = key.toTreapKey()?.precompute()
return if (treapKey == null) {
// The key is not compatible with this map type, so it's definitely not in the map.
merger(null, value)?.let { put(key, it) } ?: this
} else {
self.updateEntry(treapKey, key, value, merger, ::new) ?: clear()
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>
/**
Converts the supplied set element to a TreapKey appropriate to this type of AbstractTreapSet (sorted vs. hashed)
*/
abstract fun E.toTreapKey(): TreapKey<E>
abstract fun E.toTreapKey(): TreapKey<E>?

/**
Does this node contain the element?
Expand Down Expand Up @@ -92,7 +92,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>
}

override fun contains(element: E): Boolean =
self.find(element.toTreapKey())?.shallowContains(element) ?: false
element.toTreapKey()?.let { self.find(it) }?.shallowContains(element) ?: false

override fun containsAll(elements: Collection<E>): Boolean = elements.useAsTreap(
{ elementsTreap -> self.containsAllKeys(elementsTreap) },
Expand All @@ -110,7 +110,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>
)

override fun remove(element: E): TreapSet<E> =
self.remove(element.toTreapKey(), element) ?: clear()
element.toTreapKey()?.let { self.remove(it, element) ?: clear() } ?: this

override fun removeAll(elements: Collection<E>): TreapSet<E> = elements.useAsTreap(
{ elementsTreap -> (self difference elementsTreap) ?: clear() },
Expand All @@ -134,7 +134,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>
)

override fun findEqual(element: E): E? =
self.find(element.toTreapKey())?.shallowFindEqual(element)
element.toTreapKey()?.let { self.find(it) }?.shallowFindEqual(element)

@Suppress("UNCHECKED_CAST")
override fun single(): E = getSingleElement() ?: when {
Expand Down
6 changes: 2 additions & 4 deletions collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,8 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap<K

@Suppress("Treapability", "UNCHECKED_CAST")
override fun put(key: K, value: V): TreapMap<K, V> = when (key) {
is PrefersHashTreap -> HashTreapMap(key, value)
is Comparable<*> ->
SortedTreapMap<Comparable<Comparable<*>>, V>(key as Comparable<Comparable<*>>, value) as TreapMap<K, V>
else -> HashTreapMap(key, value)
!is Comparable<*>?, is PrefersHashTreap -> HashTreapMap(key, value)
else -> SortedTreapMap(key, value)
}

@Suppress("UNCHECKED_CAST")
Expand Down
7 changes: 2 additions & 5 deletions collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,9 @@ internal class EmptyTreapSet<@Treapable E> private constructor() : TreapSet<E>,
override fun <R : Any> mapReduce(map: (E) -> R, reduce: (R, R) -> R): R? = null
override fun <R : Any> parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R? = null

@Suppress("Treapability", "UNCHECKED_CAST")
override fun add(element: E): TreapSet<E> = when (element) {
is PrefersHashTreap -> HashTreapSet(element)
is Comparable<*> ->
SortedTreapSet<Comparable<Comparable<*>>>(element as Comparable<Comparable<*>>) as TreapSet<E>
else -> HashTreapSet(element)
!is Comparable<*>?, is PrefersHashTreap -> HashTreapSet(element)
else -> SortedTreapSet(element as E)
}

@Suppress("UNCHECKED_CAST")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ internal class HashTreapMap<@Treapable K, V>(

override fun hashCode() = computeHashCode()

override fun K.toTreapKey() = TreapKey.Hashed.FromKey(this)
override fun K.toTreapKey() = TreapKey.Hashed.fromKey(this)
override fun new(key: K, value: V): HashTreapMap<K, V> = HashTreapMap(key, value)

override fun put(key: K, value: V): TreapMap<K, V> = self.add(new(key, value))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ internal class HashTreapSet<@Treapable E>(

override fun hashCode(): Int = computeHashCode()

override fun E.toTreapKey() = TreapKey.Hashed.FromKey(this)
override fun E.toTreapKey() = TreapKey.Hashed.fromKey(this)
override fun new(element: E): HashTreapSet<E> = HashTreapSet(element)

override fun add(element: E): TreapSet<E> = self.add(new(element))
Expand Down
24 changes: 13 additions & 11 deletions collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,23 @@ import kotlinx.collections.immutable.PersistentMap
A TreapMap specific to Comparable keys. Iterates in the order defined by the objects. We store one element per
Treap node, with the map key itself as the Treap key, and an additional `value` field
*/
internal class SortedTreapMap<@Treapable K : Comparable<K>?, V>(
internal class SortedTreapMap<@Treapable K, V>(
val key: K,
val value: V,
left: SortedTreapMap<K, V>? = null,
right: SortedTreapMap<K, V>? = null
) : AbstractTreapMap<K, V, SortedTreapMap<K, V>>(left, right), TreapKey.Sorted<K> {

init { check(key is Comparable<*>?) { "SortedTreapMap keys must be Comparable" } }

override fun hashCode() = computeHashCode()

override fun K.toTreapKey() = TreapKey.Sorted.FromKey(this)
override fun K.toTreapKey() = TreapKey.Sorted.fromKey(this)

override fun new(key: K, value: V): SortedTreapMap<K, V> = SortedTreapMap(key, value)

override fun put(key: K, value: V): TreapMap<K, V> = when (key) {
is PrefersHashTreap -> HashTreapMap(key as K, value) + this
!is Comparable<*>?, is PrefersHashTreap -> HashTreapMap(key, value) + this
else -> self.add(new(key, value))
}

Expand Down Expand Up @@ -98,38 +100,38 @@ internal class SortedTreapMap<@Treapable K : Comparable<K>?, V>(
}

fun floorEntry(key: K): Map.Entry<K, V>? {
requireNotNull(key)
val cmp = key.compareTo(this.key)
val cmp = TreapKey.Sorted.fromKey(key)?.compareKeyTo(this)
return when {
cmp == null -> null
cmp < 0 -> left?.floorEntry(key)
cmp > 0 -> right?.floorEntry(key) ?: this.asEntry()
else -> this.asEntry()
}
}

fun ceilingEntry(key: K): Map.Entry<K, V>? {
requireNotNull(key)
val cmp = key.compareTo(this.key)
val cmp = TreapKey.Sorted.fromKey(key)?.compareKeyTo(this)
return when {
cmp == null -> null
cmp < 0 -> left?.ceilingEntry(key) ?: this.asEntry()
cmp > 0 -> right?.ceilingEntry(key)
else -> this.asEntry()
}
}

fun lowerEntry(key: K): Map.Entry<K, V>? {
requireNotNull(key)
val cmp = key.compareTo(this.key)
val cmp = TreapKey.Sorted.fromKey(key)?.compareKeyTo(this)
return when {
cmp == null -> null
cmp > 0 -> right?.lowerEntry(key) ?: this.asEntry()
else -> left?.lowerEntry(key)
}
}

fun higherEntry(key: K): Map.Entry<K, V>? {
requireNotNull(key)
val cmp = key.compareTo(this.key)
val cmp = TreapKey.Sorted.fromKey(key)?.compareKeyTo(this)
return when {
cmp == null -> null
cmp < 0 -> left?.higherEntry(key) ?: this.asEntry()
else -> right?.higherEntry(key)
}
Expand Down
8 changes: 5 additions & 3 deletions collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@ import kotlinx.collections.immutable.PersistentSet
A TreapSet specific to Comparable elements. Iterates in the order defined by the objects. We store one element per
Treap node, with the element itself as the Treap key.
*/
internal class SortedTreapSet<@Treapable E : Comparable<E>?>(
internal class SortedTreapSet<@Treapable E>(
override val treapKey: E,
left: SortedTreapSet<E>? = null,
right: SortedTreapSet<E>? = null
) : AbstractTreapSet<E, SortedTreapSet<E>>(left, right), TreapKey.Sorted<E> {

init { check(treapKey is Comparable<*>?) { "SortedTreapSet elements must be Comparable" } }

override fun hashCode(): Int = computeHashCode()

override fun E.toTreapKey() = TreapKey.Sorted.FromKey(this)
override fun E.toTreapKey() = TreapKey.Sorted.fromKey(this)
override fun new(element: E): SortedTreapSet<E> = SortedTreapSet(element)

override fun add(element: E): TreapSet<E> = when(element) {
is PrefersHashTreap -> HashTreapSet(element as E) + this
!is Comparable<*>?, is PrefersHashTreap -> HashTreapSet(element) + this
else -> self.add(new(element))
}

Expand Down
32 changes: 21 additions & 11 deletions collect/src/main/kotlin/com/certora/collect/TreapKey.kt
Original file line number Diff line number Diff line change
Expand Up @@ -64,26 +64,33 @@ internal interface TreapKey<@Treapable K> {
/**
A TreapKey whose underlying key implements Comparable. This allows us to sort the Treap naturally.
*/
interface Sorted<@Treapable K : Comparable<K>?> : TreapKey<K> {
interface Sorted<@Treapable K> : TreapKey<K> {
abstract override val treapKey: K

// Note that we must never compare a Hashed key with a Sorted key. We'd check that here, but this is extremely
// perf-critical code.
override fun compareKeyTo(that: TreapKey<K>): Int {
val thisTreapKey = this.treapKey
val thatTreapKey = that.treapKey
return when {
thisTreapKey === thatTreapKey -> 0
thisTreapKey == null -> -1
thatTreapKey == null -> 1
else -> thisTreapKey.compareTo(thatTreapKey)
else -> {
@Suppress("UNCHECKED_CAST")
(thisTreapKey as Comparable<K>).compareTo(thatTreapKey)
}
}
}

override fun precompute() = FromKey(treapKey)
override fun precompute() = fromKey(treapKey)!!

class FromKey<@Treapable K : Comparable<K>?>(override val treapKey: K) : Sorted<K> {
override val treapPriority = super.treapPriority // precompute the priority
companion object {
fun <@Treapable K> fromKey(key: K): Sorted<K>? = when (key) {
is Comparable<*>? -> object : Sorted<K> {
override val treapKey = key
override val treapPriority = super.treapPriority // precompute the priority
}
else -> null
}
}
}

Expand All @@ -107,11 +114,14 @@ internal interface TreapKey<@Treapable K> {
}
}

override fun precompute() = FromKey(treapKey)
override fun precompute() = fromKey(treapKey)

class FromKey<@Treapable K>(override val treapKey: K) : Hashed<K> {
override val treapKeyHashCode = treapKey.hashCode() // precompute the hash code
override val treapPriority = super.treapPriority // precompute the priority
companion object {
fun <@Treapable K> fromKey(key: K): Hashed<K> = object : Hashed<K> {
override val treapKey = key
override val treapKeyHashCode = treapKey.hashCode() // precompute the hash code
override val treapPriority = super.treapPriority // precompute the priority
}
}
}
}

0 comments on commit e0007af

Please sign in to comment.