From 889dfe2fdb6f604617f3cd4e2563666a39a2331f Mon Sep 17 00:00:00 2001 From: Eric Eilebrecht Date: Mon, 9 Dec 2024 15:56:04 -0800 Subject: [PATCH] Add `forEachEntry` and `arbitraryOrNull` (#17) - `forEachEntry`: Fast enumeration of map entries via a simple recursive walk - `arbitraryOrNull`: Quickly get a single arbitrary element/entry from a set/map. --- .../com/certora/collect/EmptyTreapMap.kt | 4 ++ .../com/certora/collect/EmptyTreapSet.kt | 1 + .../com/certora/collect/HashTreapMap.kt | 8 ++++ .../com/certora/collect/HashTreapSet.kt | 2 + .../com/certora/collect/SortedTreapMap.kt | 8 ++++ .../com/certora/collect/SortedTreapSet.kt | 1 + .../kotlin/com/certora/collect/TreapMap.kt | 7 +++ .../kotlin/com/certora/collect/TreapSet.kt | 5 +++ .../com/certora/collect/HashTreapMapTest.kt | 4 +- .../com/certora/collect/SortedTreapMapTest.kt | 4 +- .../com/certora/collect/TreapMapTest.kt | 45 ++++++++++++++++--- .../com/certora/collect/TreapSetTest.kt | 21 ++++++++- 12 files changed, 99 insertions(+), 11 deletions(-) diff --git a/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt index f022933..bdb0862 100644 --- a/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt @@ -19,6 +19,10 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap = this override fun remove(key: K, value: V): TreapMap = this + override fun arbitraryOrNull(): Map.Entry? = null + + override fun forEachEntry(action: (Map.Entry) -> Unit): Unit {} + override fun updateValues( transform: (K, V) -> R? ): TreapMap = treapMapOf() diff --git a/collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt index 4b880b9..609fa57 100644 --- a/collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/EmptyTreapSet.kt @@ -24,6 +24,7 @@ internal class EmptyTreapSet<@Treapable E> private constructor() : TreapSet, override fun retainAll(elements: Collection): TreapSet = this override fun single(): E = throw NoSuchElementException("Empty set.") override fun singleOrNull(): E? = null + override fun arbitraryOrNull(): E? = null override fun mapReduce(map: (E) -> R, reduce: (R, R) -> R): R? = null override fun parallelMapReduce(map: (E) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int): R? = null diff --git a/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt index abddbed..081dee2 100644 --- a/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt @@ -31,6 +31,8 @@ internal class HashTreapMap<@Treapable K, V>( this as? HashTreapMap ?: (this as? PersistentMap.Builder)?.build() as? HashTreapMap + override fun arbitraryOrNull(): Map.Entry? = MapEntry(key, value) + override fun getShallowMerger(merger: (K, V?, V?) -> V?): (HashTreapMap?, HashTreapMap?) -> HashTreapMap? = { t1, t2 -> var newPairs: KeyValuePairList.More? = null t1?.forEachPair { (k, v1) -> @@ -300,6 +302,12 @@ internal class HashTreapMap<@Treapable K, V>( } return result!! } + + override fun forEachEntry(action: (Map.Entry) -> Unit) { + left?.forEachEntry(action) + forEachPair { (k, v) -> action(MapEntry(k, v)) } + right?.forEachEntry(action) + } } internal interface KeyValuePairList { diff --git a/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt index fd819f0..0f6c8d5 100644 --- a/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt @@ -230,6 +230,8 @@ internal class HashTreapSet<@Treapable E>( override fun shallowGetSingleElement(): E? = element.takeIf { next == null } + override fun arbitraryOrNull(): E? = element + override fun shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R { var result: R? = null forEachNodeElement { diff --git a/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt b/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt index ae13992..2d0cdac 100644 --- a/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt @@ -31,6 +31,8 @@ internal class SortedTreapMap<@Treapable K, V>( this as? SortedTreapMap ?: (this as? PersistentMap.Builder)?.build() as? SortedTreapMap + override fun arbitraryOrNull(): Map.Entry? = MapEntry(key, value) + override fun getShallowMerger(merger: (K, V?, V?) -> V?): (SortedTreapMap?, SortedTreapMap?) -> SortedTreapMap? = { t1, t2 -> val k = (t1 ?: t2)!!.key val v1 = t1?.value @@ -141,4 +143,10 @@ internal class SortedTreapMap<@Treapable K, V>( fun lastEntry(): Map.Entry? = right?.lastEntry() ?: this.asEntry() override fun shallowMapReduce(map: (K, V) -> R, reduce: (R, R) -> R): R = map(key, value) + + override fun forEachEntry(action: (Map.Entry) -> Unit) { + left?.forEachEntry(action) + action(this.asEntry()) + right?.forEachEntry(action) + } } diff --git a/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt b/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt index d683b71..87e9f4b 100644 --- a/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/SortedTreapSet.kt @@ -50,6 +50,7 @@ internal class SortedTreapSet<@Treapable E>( override fun shallowRemoveAll(predicate: (E) -> Boolean): SortedTreapSet? = this.takeIf { !predicate(treapKey) } override fun shallowComputeHashCode(): Int = treapKey.hashCode() override fun shallowGetSingleElement(): E = treapKey + override fun arbitraryOrNull(): E? = treapKey override fun shallowForEach(action: (element: E) -> Unit): Unit { action(treapKey) } override fun shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R = map(treapKey) } diff --git a/collect/src/main/kotlin/com/certora/collect/TreapMap.kt b/collect/src/main/kotlin/com/certora/collect/TreapMap.kt index c9adc55..6c5dc0e 100644 --- a/collect/src/main/kotlin/com/certora/collect/TreapMap.kt +++ b/collect/src/main/kotlin/com/certora/collect/TreapMap.kt @@ -23,6 +23,13 @@ public sealed interface TreapMap : PersistentMap { @Suppress("Treapability") override fun builder(): Builder = TreapMapBuilder(this) + /** + Returns an arbitrary entry from the map, or null if the map is empty. + */ + public fun arbitraryOrNull(): Map.Entry? + + public fun forEachEntry(action: (Map.Entry) -> Unit): Unit + public fun merge( m: Map, merger: (K, V?, V?) -> V? diff --git a/collect/src/main/kotlin/com/certora/collect/TreapSet.kt b/collect/src/main/kotlin/com/certora/collect/TreapSet.kt index 14e719a..66de30b 100644 --- a/collect/src/main/kotlin/com/certora/collect/TreapSet.kt +++ b/collect/src/main/kotlin/com/certora/collect/TreapSet.kt @@ -42,6 +42,11 @@ public sealed interface TreapSet : PersistentSet { */ public fun singleOrNull(): T? + /** + Returns an arbitrary element from the set, or null if the set is empty. + */ + public fun arbitraryOrNull(): T? + /** If this set contains an element that compares equal to the specified [element], returns that element instance. diff --git a/collect/src/test/kotlin/com/certora/collect/HashTreapMapTest.kt b/collect/src/test/kotlin/com/certora/collect/HashTreapMapTest.kt index edb962e..3a47035 100644 --- a/collect/src/test/kotlin/com/certora/collect/HashTreapMapTest.kt +++ b/collect/src/test/kotlin/com/certora/collect/HashTreapMapTest.kt @@ -6,9 +6,9 @@ import kotlinx.serialization.DeserializationStrategy /** Tests for [HashTreapMap]. */ class HashTreapMapTest: TreapMapTest() { override fun makeKey(value: Int, code: Int) = HashTestKey(value, code) - override fun makeMap(): MutableMap = treapMapOf().builder() + override fun makeMap(): TreapMap.Builder = treapMapOf().builder() override fun makeBaseline(): MutableMap = HashMap() - override fun makeMap(other: Map): MutableMap = makeMap().apply { putAll(other) } + override fun makeMap(other: Map): TreapMap.Builder = makeMap().apply { putAll(other) } override fun makeBaseline(other: Map): MutableMap = HashMap(other) override fun makeMapOfInts(): TreapMap = treapMapOf() override fun makeMapOfInts(other: Map) = makeMapOfInts().apply { putAll(other) } diff --git a/collect/src/test/kotlin/com/certora/collect/SortedTreapMapTest.kt b/collect/src/test/kotlin/com/certora/collect/SortedTreapMapTest.kt index 85d15e1..78cf278 100644 --- a/collect/src/test/kotlin/com/certora/collect/SortedTreapMapTest.kt +++ b/collect/src/test/kotlin/com/certora/collect/SortedTreapMapTest.kt @@ -10,9 +10,9 @@ import java.util.TreeMap class SortedTreapMapTest: TreapMapTest() { override fun makeKey(value: Int, code: Int) = ComparableTestKey(value, code) override val allowNullKeys = false - override fun makeMap(): MutableMap = treapMapOf().builder() as MutableMap + override fun makeMap(): TreapMap.Builder = treapMapOf().builder() as TreapMap.Builder override fun makeBaseline(): MutableMap = TreeMap() - override fun makeMap(other: Map): MutableMap = makeMap().apply { putAll(other) } + override fun makeMap(other: Map): TreapMap.Builder = makeMap().apply { putAll(other) } override fun makeBaseline(other: Map): MutableMap = TreeMap(other) override fun makeMapOfInts(): TreapMap = treapMapOf() as TreapMap override fun makeMapOfInts(other: Map) = makeMapOfInts().apply { putAll(other) } diff --git a/collect/src/test/kotlin/com/certora/collect/TreapMapTest.kt b/collect/src/test/kotlin/com/certora/collect/TreapMapTest.kt index 28b4bc2..3bece70 100644 --- a/collect/src/test/kotlin/com/certora/collect/TreapMapTest.kt +++ b/collect/src/test/kotlin/com/certora/collect/TreapMapTest.kt @@ -14,14 +14,14 @@ abstract class TreapMapTest { abstract fun makeKey(value: Int, code: Int = value.hashCode()): TestKey open val allowNullKeys = true - abstract fun makeMap(): MutableMap + abstract fun makeMap(): TreapMap.Builder abstract fun makeBaseline(): MutableMap - abstract fun makeMap(other: Map): MutableMap + abstract fun makeMap(other: Map): TreapMap.Builder abstract fun makeBaseline(other: Map): MutableMap open fun assertOrderedIteration(expected: Iterator<*>, actual: Iterator<*>) { } - fun assertVeryEqual(expected: Map<*,*>, actual: Map<*,*>) { + fun assertVeryEqual(expected: Map<*,*>, actual: TreapMap<*,*>) { assertEquals(expected, actual) assertTrue(actual.equals(expected)) assertEquals(expected.hashCode(), actual.hashCode()) @@ -41,9 +41,27 @@ abstract class TreapMapTest { val actualValues = actual.values assertEquals(expectedValues.size, actualValues.size) assertOrderedIteration(expectedValues.iterator(), actualValues.iterator()) + + val actualForEachEntries = mutableListOf>() + actual.forEach { + actualForEachEntries += it + } + assertOrderedIteration(expected.entries.iterator(), actualForEachEntries.iterator()) + + val actualForEachEntryEntries = mutableListOf>() + actual.forEachEntry { + actualForEachEntryEntries += it + } + assertOrderedIteration(expected.entries.iterator(), actualForEachEntryEntries.iterator()) + + assertEquals(actualForEachEntries, actualForEachEntryEntries) + } + + fun assertVeryEqual(expected: Map<*, *>, actual: TreapMap.Builder<*, *>) { + assertVeryEqual(expected, actual.build()) } - fun assertEqualMutation(baseline: MutableMap, map: MutableMap, action: MutableMap.() -> TResult) { + fun assertEqualMutation(baseline: MutableMap, map: TreapMap.Builder, action: MutableMap.() -> TResult) { assertEqualResult(baseline, map, action) assertVeryEqual(baseline, map) } @@ -350,7 +368,7 @@ abstract class TreapMapTest { @Suppress("UNCHECKED_CAST") val db = Json.decodeFromString(getBaseDeserializer()!!, bs) as Map @Suppress("UNCHECKED_CAST") - val dm = Json.decodeFromString(getDeserializer()!!, ms) as Map + val dm = Json.decodeFromString(getDeserializer()!!, ms) as TreapMap assertVeryEqual(db, dm) } @@ -432,4 +450,21 @@ abstract class TreapMapTest { testMapOf(1 to 2, 2 to 3).zip(testMapOf(1 to 3, 2 to 4)).toSet() ) } + + @Test + fun arbitraryOrNull() { + val m = makeMap() + assertNull(m.build().arbitraryOrNull()) + + m[makeKey(1, 1)] = 1 + assertEquals(1, m.build().arbitraryOrNull()!!.value) + + m[makeKey(2, 1)] = 2 + assertTrue(m.build().arbitraryOrNull()!!.value in 1..2) + + for (it in 3..100) { + m[makeKey(it)] = it + } + assertTrue(m.build().arbitraryOrNull()!!.value in 1..100) + } } diff --git a/collect/src/test/kotlin/com/certora/collect/TreapSetTest.kt b/collect/src/test/kotlin/com/certora/collect/TreapSetTest.kt index f3af578..da4dd9e 100644 --- a/collect/src/test/kotlin/com/certora/collect/TreapSetTest.kt +++ b/collect/src/test/kotlin/com/certora/collect/TreapSetTest.kt @@ -16,10 +16,10 @@ abstract class TreapSetTest { abstract fun makeKey(value: Int, code: Int = value.hashCode()): TestKey open val nullKeysAllowed: Boolean get() = true - fun makeSet(): MutableSet = treapSetOf().builder() + fun makeSet(): TreapSet.Builder = treapSetOf().builder() abstract fun makeBaseline(): MutableSet - fun makeSet(other: Collection): MutableSet = makeSet().also { it += other } + fun makeSet(other: Collection): TreapSet.Builder = makeSet().also { it += other } fun makeBaseline(other: Collection): MutableSet = makeSet().also { it += other } open fun assertOrderedIteration(expected: Iterator<*>, actual: Iterator<*>) {} @@ -411,5 +411,22 @@ abstract class TreapSetTest { assertVeryEqual(s, rs) assertVeryEqual(b, rs) } + + @Test + fun arbitraryOrNull() { + val s = makeSet() + assertNull(s.build().arbitraryOrNull()) + + s += makeKey(1, 1) + assertEquals(makeKey(1, 1), s.build().arbitraryOrNull()) + + s += makeKey(2, 1) + assertTrue(s.build().arbitraryOrNull() in (1..2).map { makeKey(it) }) + + for (it in 3..100) { + s += makeKey(it) + } + assertTrue(s.build().arbitraryOrNull() in (1..100).map { makeKey(it) }) + } }