Skip to content

Commit

Permalink
TreapMap.keys as TreapSet (#20)
Browse files Browse the repository at this point in the history
Currently it is quite expensive to do operations like this:

```
set2 = map1.keys intersect set1
```

This is because the `keys` property on `TreapMap` is an `ImmuableSet`,
not a `TreapSet`, so the intersection operation must convert it to a
`TreapSet` first, using the generic `Set`->`TreapSet` conversion, which
is rather expensive.

This is a shame, because a `TreapMap<K, V>` of course is already
structured the same way as a `TreapSet<K>`, so the conversion to
`TreapSet` should be very cheap.

This PR redefines the `keys` property as a `TreapSet`, and defines
efficient implementations of this. Depending on usage, we can sometimes
avoid the conversion to `TreapSet` altogether; if not, we can do the
conversion as a straightforward projection of the existing `TreapMap`
structure, avoiding most of the overhead of the generic conversion.

We also add `single` and `singleOrNull` methods to `TreapMap`, which are
useful in their own right, and can be used to avoid the conversion to
`TreapSet` in some cases.
  • Loading branch information
ericeil authored Jan 7, 2025
1 parent 78857f0 commit e2d8c8e
Show file tree
Hide file tree
Showing 10 changed files with 122 additions and 30 deletions.
59 changes: 59 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/AbstractKeySet.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package com.certora.collect

/**
Presents the keys of a [TreapMap] as a [TreapSet].
The idea here is that a `TreapMap<K, *>` is stored with the same Treap structure as a `TreapSet<K>`, so we can very
quickly create the corresponding `TreapSet<K>` when needed, in O(n) time (as opposed to the naive O(n*log(n))
method).
We lazily initialize the set, so that we don't create it until we need it. For many operations, we can avoid
creating the set entirely, and just use the map directly. However, many operations, e.g. [addAll]/[union] and
[retainAll/intersect], are much more efficient when we have a [TreapSet], so we create it when needed.
*/
internal abstract class AbstractKeySet<@Treapable K, S : TreapSet<K>> : TreapSet<K> {
/**
The map whose keys we are presenting as a set. We prefer to use the map directly when possible, so we don't
need to create the set.
*/
abstract val map: AbstractTreapMap<K, *, *>
/**
The set of keys. This is a lazy property so that we don't create the set until we need it.
*/
abstract val keys: Lazy<S>

@Suppress("Treapability")
override fun hashCode() = keys.value.hashCode()
override fun equals(other: Any?) = keys.value.equals(other)
override fun toString() = keys.value.toString()

override val size get() = map.size
override fun isEmpty() = map.isEmpty()
override fun clear() = treapSetOf<K>()

override operator fun contains(element: K) = map.containsKey(element)
override operator fun iterator() = map.entrySequence().map { it.key }.iterator()

override fun add(element: K) = keys.value.add(element)
override fun addAll(elements: Collection<K>) = keys.value.addAll(elements)
override fun remove(element: K) = keys.value.remove(element)
override fun removeAll(elements: Collection<K>) = keys.value.removeAll(elements)
override fun removeAll(predicate: (K) -> Boolean) = keys.value.removeAll(predicate)
override fun retainAll(elements: Collection<K>) = keys.value.retainAll(elements)

override fun single() = map.single().key
override fun singleOrNull() = map.singleOrNull()?.key
override fun arbitraryOrNull() = map.arbitraryOrNull()?.key

override fun containsAny(elements: Iterable<K>) = keys.value.containsAny(elements)
override fun containsAny(predicate: (K) -> Boolean) = (this as Iterable<K>).any(predicate)
override fun containsAll(elements: Collection<K>) = keys.value.containsAll(elements)
override fun findEqual(element: K) = keys.value.findEqual(element)

override fun forEachElement(action: (K) -> Unit) = map.forEachEntry { action(it.key) }

override fun <R : Any> mapReduce(map: (K) -> R, reduce: (R, R) -> R) =
this.map.mapReduce({ k, _ -> map(k) }, reduce)
override fun <R : Any> parallelMapReduce(map: (K) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int) =
this.map.parallelMapReduce({ k, _ -> map(k) }, reduce, parallelThresholdLog2)
}
12 changes: 4 additions & 8 deletions collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
return when {
otherMap == null -> false
otherMap === this -> true
otherMap.isEmpty() -> false // NB AbstractTreapMap always contains at least one entry
else -> otherMap.useAsTreap(
{ otherTreap -> this.self.deepEquals(otherTreap) },
{ other.size == this.size && other.entries.all { this.containsEntry(it) }}
Expand All @@ -112,6 +113,9 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
override val size: Int get() = computeSize()
override fun isEmpty(): Boolean = false

// NB AbstractTreapMap always contains at least one entry
override fun single() = singleOrNull() ?: throw IllegalArgumentException("Map contains more than one entry")

override fun containsKey(key: K) =
key.toTreapKey()?.let { self.find(it) }?.shallowContainsKey(key) ?: false

Expand Down Expand Up @@ -140,14 +144,6 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
override fun iterator() = entrySequence().iterator()
}

override val keys: ImmutableSet<K>
get() = object: AbstractSet<K>(), ImmutableSet<K> {
override val size get() = this@AbstractTreapMap.size
override fun isEmpty() = this@AbstractTreapMap.isEmpty()
override operator fun contains(element: K) = containsKey(element)
override operator fun iterator() = entrySequence().map { it.key }.iterator()
}

override val values: ImmutableCollection<V>
get() = object: AbstractCollection<V>(), ImmutableCollection<V> {
override val size get() = this@AbstractTreapMap.size
Expand Down
19 changes: 2 additions & 17 deletions collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>
*/
abstract fun shallowForEach(action: (element: E) -> Unit): Unit

abstract fun shallowGetSingleElement(): E?

abstract infix fun shallowUnion(that: S): S
abstract infix fun shallowIntersect(that: S): S?
abstract infix fun shallowDifference(that: S): S?
Expand Down Expand Up @@ -85,6 +83,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>
other == null -> false
this === other -> true
other !is Set<*> -> false
other.isEmpty() -> false // NB AbstractTreapSet always contains at least one element
else -> (other as Set<E>).useAsTreap(
{ otherTreap -> this.self.deepEquals(otherTreap) },
{ this.size == other.size && this.containsAll(other) }
Expand Down Expand Up @@ -136,26 +135,12 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>
override fun findEqual(element: E): E? =
element.toTreapKey()?.let { self.find(it) }?.shallowFindEqual(element)

@Suppress("UNCHECKED_CAST")
override fun single(): E = getSingleElement() ?: when {
isEmpty() -> throw NoSuchElementException("Set is empty")
size > 1 -> throw IllegalArgumentException("Set has more than one element")
else -> null as E // The single element must have been null!
}

override fun singleOrNull(): E? = getSingleElement()

override fun forEachElement(action: (element: E) -> Unit): Unit {
left?.forEachElement(action)
shallowForEach(action)
right?.forEachElement(action)
}

internal fun getSingleElement(): E? = when {
left === null && right === null -> shallowGetSingleElement()
else -> null
}

override fun <R : Any> mapReduce(map: (E) -> R, reduce: (R, R) -> R): R =
notForking(self) { mapReduceImpl(map, reduce) }

Expand Down Expand Up @@ -186,7 +171,7 @@ internal infix fun <@Treapable E, S : AbstractTreapSet<E, S>> S?.treapUnion(that
this == null -> that
that == null -> this
this === that -> this
that.getSingleElement() != null -> add(that)
that.singleOrNull() != null -> add(that)
else -> {
// remember, a.comparePriorityTo(b)==0 <=> a.compareKeyTo(b)==0
val c = this.comparePriorityTo(that)
Expand Down
4 changes: 3 additions & 1 deletion collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap<K
override fun remove(key: K): TreapMap<K, V> = this
override fun remove(key: K, value: V): TreapMap<K, V> = this

override fun single(): Map.Entry<K, V> = throw NoSuchElementException("Empty map.")
override fun singleOrNull(): Map.Entry<K, V>? = null
override fun arbitraryOrNull(): Map.Entry<K, V>? = null

override fun forEachEntry(action: (Map.Entry<K, V>) -> Unit): Unit {}
Expand Down Expand Up @@ -73,7 +75,7 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap<K
m.asSequence().map { MapEntry(it.key, null to it.value) }

override val entries: ImmutableSet<Map.Entry<K, V>> get() = persistentSetOf<Map.Entry<K, V>>()
override val keys: ImmutableSet<K> get() = persistentSetOf<K>()
override val keys: TreapSet<K> get() = treapSetOf<K>()
override val values: ImmutableCollection<V> get() = persistentSetOf<V>()

@Suppress("Treapability", "UNCHECKED_CAST")
Expand Down
14 changes: 14 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ internal class HashTreapMap<@Treapable K, V>(
this as? HashTreapMap<K, V>
?: (this as? PersistentMap.Builder<K, V>)?.build() as? HashTreapMap<K, V>

override fun singleOrNull() = MapEntry(key, value).takeIf { next == null && left == null && right == null }
override fun arbitraryOrNull(): Map.Entry<K, V>? = MapEntry(key, value)

override fun getShallowMerger(
Expand Down Expand Up @@ -358,6 +359,17 @@ internal class HashTreapMap<@Treapable K, V>(
forEachPair { (k, v) -> action(MapEntry(k, v)) }
right?.forEachEntry(action)
}

private fun treapSetFromKeys(): HashTreapSet<K> =
HashTreapSet(treapKey, next?.toKeyList(), left?.treapSetFromKeys(), right?.treapSetFromKeys())

inner class KeySet : AbstractKeySet<K, HashTreapSet<K>>() {
override val map get() = this@HashTreapMap
override val keys = lazy { treapSetFromKeys() }
override fun hashCode() = super.hashCode() // avoids treapability warning
}

override val keys get() = KeySet()
}

internal interface KeyValuePairList<K, V> {
Expand All @@ -367,6 +379,8 @@ internal interface KeyValuePairList<K, V> {
operator fun component1() = key
operator fun component2() = value

fun toKeyList(): ElementList.More<K> = ElementList.More(key, next?.toKeyList())

class More<K, V>(
override val key: K,
override val value: V,
Expand Down
10 changes: 8 additions & 2 deletions collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ internal class HashTreapSet<@Treapable E>(
override fun Iterable<E>.toTreapSetOrNull(): HashTreapSet<E>? =
(this as? HashTreapSet<E>)
?: (this as? TreapSet.Builder<E>)?.build() as? HashTreapSet<E>
?: (this as? HashTreapMap<E, *>.KeySet)?.keys?.value

private inline fun ElementList<E>?.forEachNodeElement(action: (E) -> Unit) {
var current = this
Expand Down Expand Up @@ -228,8 +229,13 @@ internal class HashTreapSet<@Treapable E>(
}
}.iterator()

override fun shallowGetSingleElement(): E? = element.takeIf { next == null }

override fun singleOrNull(): E? = element.takeIf { next == null && left == null && right == null }
override fun single(): E {
if (next != null || left != null || right != null) {
throw IllegalArgumentException("Set contains more than one element")
}
return element
}
override fun arbitraryOrNull(): E? = element

override fun <R : Any> shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R {
Expand Down
12 changes: 12 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ internal class SortedTreapMap<@Treapable K, V>(
this as? SortedTreapMap<K, V>
?: (this as? PersistentMap.Builder<K, V>)?.build() as? SortedTreapMap<K, V>

override fun singleOrNull(): Map.Entry<K, V>? = MapEntry(key, value).takeIf { left == null && right == null }
override fun arbitraryOrNull(): Map.Entry<K, V>? = MapEntry(key, value)

override fun getShallowUnionMerger(
Expand Down Expand Up @@ -170,4 +171,15 @@ internal class SortedTreapMap<@Treapable K, V>(
action(this.asEntry())
right?.forEachEntry(action)
}

private fun treapSetFromKeys(): SortedTreapSet<K> =
SortedTreapSet(treapKey, left?.treapSetFromKeys(), right?.treapSetFromKeys())

inner class KeySet : AbstractKeySet<K, SortedTreapSet<K>>() {
override val map get() = this@SortedTreapMap
override val keys = lazy { treapSetFromKeys() }
override fun hashCode() = super.hashCode() // avoids treapability warning
}

override val keys get() = KeySet()
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ internal class SortedTreapSet<@Treapable E>(
override fun Iterable<E>.toTreapSetOrNull(): SortedTreapSet<E>? =
(this as? SortedTreapSet<E>)
?: (this as? PersistentSet.Builder<E>)?.build() as? SortedTreapSet<E>
?: (this as? SortedTreapMap<E, *>.KeySet)?.keys?.value

override val self get() = this
override fun iterator(): Iterator<E> = this.asTreapSequence().map { it.treapKey }.iterator()
Expand All @@ -49,7 +50,13 @@ internal class SortedTreapSet<@Treapable E>(
override fun shallowRemove(element: E): SortedTreapSet<E>? = null
override fun shallowRemoveAll(predicate: (E) -> Boolean): SortedTreapSet<E>? = this.takeIf { !predicate(treapKey) }
override fun shallowComputeHashCode(): Int = treapKey.hashCode()
override fun shallowGetSingleElement(): E = treapKey
override fun singleOrNull(): E? = treapKey.takeIf { left == null && right == null }
override fun single(): E {
if (left != null || right != null) {
throw IllegalArgumentException("Set contains more than one element")
}
return treapKey
}
override fun arbitraryOrNull(): E? = treapKey
override fun shallowForEach(action: (element: E) -> Unit): Unit { action(treapKey) }
override fun <R : Any> shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R = map(treapKey)
Expand Down
11 changes: 11 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/TreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ public sealed interface TreapMap<K, V> : PersistentMap<K, V> {
override fun remove(key: K, value: @UnsafeVariance V): TreapMap<K, V>
override fun putAll(m: Map<out K, @UnsafeVariance V>): TreapMap<K, V>
override fun clear(): TreapMap<K, V>
override val keys: TreapSet<K>

/**
A [PersistentMap.Builder] that produces a [TreapMap].
Expand All @@ -23,6 +24,16 @@ public sealed interface TreapMap<K, V> : PersistentMap<K, V> {
@Suppress("Treapability")
override fun builder(): Builder<K, @UnsafeVariance V> = TreapMapBuilder(this)

/**
If this map contains exactly one entry, returns that entry. Otherwise, throws.
*/
public fun single(): Map.Entry<K, V>

/**
If this map contains exactly one entry, returns that entry. Otherwise, returns null
*/
public fun singleOrNull(): Map.Entry<K, V>?

/**
Returns an arbitrary entry from the map, or null if the map is empty.
*/
Expand Down
2 changes: 1 addition & 1 deletion collect/src/main/kotlin/com/certora/collect/TreapSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public sealed interface TreapSet<out T> : PersistentSet<T> {
public fun containsAny(predicate: (T) -> Boolean): Boolean

/**
If this set contains exactly one element, returns that element. Otherwise, throws [NoSuchElementException].
If this set contains exactly one element, returns that element. Otherwise, throws.
*/
public fun single(): T

Expand Down

0 comments on commit e2d8c8e

Please sign in to comment.