diff --git a/stdlib/src/collections/linked_list.mojo b/stdlib/src/collections/linked_list.mojo index 515b6404a0..2194745022 100644 --- a/stdlib/src/collections/linked_list.mojo +++ b/stdlib/src/collections/linked_list.mojo @@ -17,7 +17,7 @@ from collections._index_normalization import normalize_index trait WritableCollectionElement(CollectionElement, Writable): - """A trait that combines CollectionElement and Writable traits. + """A trait that combines the CollectionElement and Writable traits. This trait requires types to implement both CollectionElement and Writable interfaces, allowing them to be used in collections and written to output. @@ -26,26 +26,44 @@ trait WritableCollectionElement(CollectionElement, Writable): pass +trait EqualityComparableWritableCollectionElement( + WritableCollectionElement, EqualityComparable +): + """A trait that combines the CollectionElement, Writable and + EqualityComparable traits. + + This trait requires types to implement CollectionElement, Writable and + EqualityComparable interfaces, allowing them to be used in collections, + compared, and written to output. + """ + + pass + + @value -struct Node[ElementType: WritableCollectionElement]: +struct Node[ + ElementType: CollectionElement, +]: """A node in a linked list data structure. Parameters: ElementType: The type of element stored in the node. """ + alias NodePointer = UnsafePointer[Self] + var value: ElementType """The value stored in this node.""" - var prev: UnsafePointer[Node[ElementType]] + var prev: Self.NodePointer """The previous node in the list.""" - var next: UnsafePointer[Node[ElementType]] + var next: Self.NodePointer """The next node in the list.""" fn __init__( out self, owned value: ElementType, - prev: Optional[UnsafePointer[Node[ElementType]]], - next: Optional[UnsafePointer[Node[ElementType]]], + prev: Optional[Self.NodePointer], + next: Optional[Self.NodePointer], ): """Initialize a new Node with the given value and optional prev/next pointers. @@ -59,19 +77,29 @@ struct Node[ElementType: WritableCollectionElement]: self.prev = prev.value() if prev else __type_of(self.prev)() self.next = next.value() if next else __type_of(self.next)() - fn __str__(self) -> String: + fn __str__[ + ElementType: WritableCollectionElement + ](self: Node[ElementType]) -> String: """Convert this node's value to a string representation. + Parameters: + ElementType: Used to conditionally enable this function if + `ElementType` is `Writable`. + Returns: String representation of the node's value. """ - return String.write(self) + return String.write(self.value) @no_inline - fn write_to[W: Writer](self, mut writer: W): + fn write_to[ + ElementType: WritableCollectionElement, W: Writer + ](self: Node[ElementType], mut writer: W): """Write this node's value to the given writer. Parameters: + ElementType: Used to conditionally enable this function if + `ElementType` is `Writable`. W: The type of writer to write the value to. Args: @@ -80,7 +108,9 @@ struct Node[ElementType: WritableCollectionElement]: writer.write(self.value) -struct LinkedList[ElementType: WritableCollectionElement]: +struct LinkedList[ + ElementType: CollectionElement, +]: """A doubly-linked list implementation. A doubly-linked list is a data structure where each element points to both @@ -89,12 +119,14 @@ struct LinkedList[ElementType: WritableCollectionElement]: Parameters: ElementType: The type of elements stored in the list. Must implement - WritableCollectionElement. + CollectionElement. """ - var _head: UnsafePointer[Node[ElementType]] + alias NodePointer = UnsafePointer[Node[ElementType]] + + var _head: Self.NodePointer """The first node in the list.""" - var _tail: UnsafePointer[Node[ElementType]] + var _tail: Self.NodePointer """The last node in the list.""" var _size: Int """The number of elements in the list.""" @@ -114,30 +146,44 @@ struct LinkedList[ElementType: WritableCollectionElement]: self = Self(elements=elements^) fn __init__(out self, *, owned elements: VariadicListMem[ElementType, _]): - """Initialize a linked list with the given elements. + """ + Construct a list from a `VariadicListMem`. Args: - elements: Variable number of elements to initialize the list with. + elements: The elements to add to the list. """ self = Self() - for elem in elements: - self.append(elem[]) + var length = len(elements) + + for i in range(length): + var src = UnsafePointer.address_of(elements[i]) + var node = Self.NodePointer.alloc(1) + var dst = UnsafePointer.address_of(node[].value) + src.move_pointee_into(dst) + node[].next = Self.NodePointer() + node[].prev = self._tail + if not self._tail: + self._head = node + self._tail = node + else: + self._tail[].next = node + self._tail = node # Do not destroy the elements when their backing storage goes away. __mlir_op.`lit.ownership.mark_destroyed`( __get_mvalue_as_litref(elements) ) + self._size = length + fn __copyinit__(mut self, read other: Self): """Initialize this list as a copy of another list. Args: other: The list to copy from. """ - self._head = other._head - self._tail = other._tail - self._size = other._size + self = other.copy() fn __moveinit__(mut self, owned other: Self): """Initialize this list by moving elements from another list. @@ -167,10 +213,12 @@ struct LinkedList[ElementType: WritableCollectionElement]: Args: value: The value to append. """ - var node = Node[ElementType](value^, self._tail, None) - var addr = UnsafePointer[__type_of(node)].alloc(1) - addr.init_pointee_move(node) - if self: + var addr = Self.NodePointer.alloc(1) + var value_ptr = UnsafePointer.address_of(addr[].value) + value_ptr.init_pointee_move(value^) + addr[].prev = self._tail + addr[].next = Self.NodePointer() + if self._tail: self._tail[].next = addr else: self._head = addr @@ -205,20 +253,138 @@ struct LinkedList[ElementType: WritableCollectionElement]: self._tail = self._head self._head = prev - fn pop(mut self) -> ElementType: - """Remove and return the first element of the list. + fn pop(mut self) raises -> ElementType: + """Remove and return the last element of the list. Returns: - The first element in the list. + The last element in the list. """ var elem = self._tail + if not elem: + raise "Pop on empty list." + var value = elem[].value self._tail = elem[].prev self._size -= 1 if self._size == 0: self._head = __type_of(self._head)() + else: + self._tail[].next = Self.NodePointer() + elem.free() return value^ + fn pop[I: Indexer](mut self, owned i: I) raises -> ElementType: + """ + Remove the ith element of the list, counting from the tail if + given a negative index. + + Parameters: + I: The type of index to use. + + Args: + i: The index of the element to get. + + Returns: + Ownership of the indicated element. + """ + var current = self._get_node_ptr(Int(i)) + + if not current: + raise "Invalid index for pop" + else: + var node = current[] + if node.prev: + node.prev[].next = node.next + else: + self._head = node.next + if node.next: + node.next[].prev = node.prev + else: + self._tail = node.prev + + var data = node.value^ + + # Aside from T, destructor is trivial + __mlir_op.`lit.ownership.mark_destroyed`( + __get_mvalue_as_litref(node) + ) + current.free() + self._size -= 1 + return data^ + + fn pop_if_present(mut self) -> Optional[ElementType]: + """Removes the head of the list and returns it, if it exists. + + Returns: + The head of the list, if it was present. + """ + var elem = self._tail + if not elem: + return Optional[ElementType]() + var value = elem[].value + self._tail = elem[].prev + self._size -= 1 + if self._size == 0: + self._head = __type_of(self._head)() + else: + self._tail[].next = Self.NodePointer() + elem.free() + return value^ + + fn pop_if_present[ + I: Indexer + ](mut self, owned i: I) -> Optional[ElementType]: + """ + Remove the ith element of the list, counting from the tail if + given a negative index. + + Parameters: + I: The type of index to use. + + Args: + i: The index of the element to get. + + Returns: + The element, if it was found. + """ + var current = self._get_node_ptr(Int(i)) + + if not current: + return Optional[ElementType]() + else: + var node = current[] + if node.prev: + node.prev[].next = node.next + else: + self._head = node.next + if node.next: + node.next[].prev = node.prev + else: + self._tail = node.prev + + var data = node.value^ + + # Aside from T, destructor is trivial + __mlir_op.`lit.ownership.mark_destroyed`( + __get_mvalue_as_litref(node) + ) + current.free() + self._size -= 1 + return Optional[ElementType](data^) + + fn clear(mut self): + """Removes all elements from the list.""" + var current = self._head + while current: + var old = current + current = current[].next + old.destroy_pointee() + old.free() + + self._head = Self.NodePointer() + self._tail = Self.NodePointer() + self._size = 0 + fn copy(self) -> Self: """Create a deep copy of the list. @@ -232,6 +398,180 @@ struct LinkedList[ElementType: WritableCollectionElement]: curr = curr[].next return new^ + fn insert(mut self, owned idx: Int, owned elem: ElementType) raises: + """ + Insert an element `elem` into the list at index `idx`. + + Args: + idx: The index to insert `elem` at. + elem: The item to insert into the list. + """ + var i = max(0, index(idx) if idx >= 0 else index(idx) + len(self)) + + if i == 0: + var node = Self.NodePointer.alloc(1) + if not node: + raise "OOM" + node.init_pointee_move( + Node[ElementType](elem^, Self.NodePointer(), Self.NodePointer()) + ) + + if self._head: + node[].next = self._head + self._head[].prev = node + + self._head = node + + if not self._tail: + self._tail = node + + self._size += 1 + return + + i -= 1 + + var current = self._get_node_ptr(i) + if current: + var next = current[].next + var node = Self.NodePointer.alloc(1) + if not node: + raise "OOM" + var data = UnsafePointer.address_of(node[].value) + data[] = elem^ + node[].next = next + node[].prev = current + if next: + next[].prev = node + current[].next = node + if node[].next == Self.NodePointer(): + self._tail = node + if node[].prev == Self.NodePointer(): + self._head = node + self._size += 1 + else: + raise "index out of bounds" + + fn extend(mut self, owned other: Self): + """ + Extends the list with another. + O(1) time complexity. + + Args: + other: The list to append to this one. + """ + if self._tail: + self._tail[].next = other._head + if other._head: + other._head[].prev = self._tail + if other._tail: + self._tail = other._tail + + self._size += other._size + else: + self._head = other._head + self._tail = other._tail + self._size = other._size + + other._head = Self.NodePointer() + other._tail = Self.NodePointer() + + fn count[ + ElementType: EqualityComparableCollectionElement + ](self: LinkedList[ElementType], read elem: ElementType) -> UInt: + """ + Count the occurrences of `elem` in the list. + + Parameters: + ElementType: The list element type, used to conditionally enable the function. + + Args: + elem: The element to search for. + + Returns: + The number of occurrences of `elem` in the list. + """ + var current = self._head + var count = 0 + while current: + if current[].value == elem: + count += 1 + + current = current[].next + + return count + + fn __contains__[ + ElementType: EqualityComparableCollectionElement, // + ](self: LinkedList[ElementType], value: ElementType) -> Bool: + """ + Checks if the list contains `value`. + + Parameters: + ElementType: The list element type, used to conditionally enable the function. + + Args: + value: The value to search for in the list. + + Returns: + Whether the list contains `value`. + """ + var current = self._head + while current: + if current[].value == value: + return True + current = current[].next + + return False + + fn __eq__[ + ElementType: EqualityComparableCollectionElement, // + ]( + read self: LinkedList[ElementType], read other: LinkedList[ElementType] + ) -> Bool: + """ + Checks if the two lists are equal. + + Parameters: + ElementType: The list element type, used to conditionally enable the function. + + Args: + other: The list to compare to. + + Returns: + Whether the lists are equal. + """ + if self._size != other._size: + return False + + var self_cursor = self._head + var other_cursor = other._head + + while self_cursor: + if self_cursor[].value != other_cursor[].value: + return False + + self_cursor = self_cursor[].next + other_cursor = other_cursor[].next + + return True + + fn __ne__[ + ElementType: EqualityComparableCollectionElement, // + ](self: LinkedList[ElementType], other: LinkedList[ElementType]) -> Bool: + """ + Checks if the two lists are not equal. + + Parameters: + ElementType: The list element type, used to conditionally enable the function. + + Args: + other: The list to compare to. + + Returns: + Whether the lists are not equal. + """ + return not (self == other) + fn _get_node_ptr(ref self, index: Int) -> UnsafePointer[Node[ElementType]]: """Get a pointer to the node at the specified index. @@ -297,17 +637,31 @@ struct LinkedList[ElementType: WritableCollectionElement]: """ return len(self) != 0 - fn __str__(self) -> String: + fn __str__[ + ElementType: WritableCollectionElement + ](self: LinkedList[ElementType]) -> String: """Convert the list to its string representation. + Parameters: + ElementType: Used to conditionally enable this function when + `ElementType` is `Writable`. + Returns: String representation of the list. """ - return String.write(self) + var writer = String() + self._write(writer) + return writer - fn __repr__(self) -> String: + fn __repr__[ + ElementType: WritableCollectionElement + ](self: LinkedList[ElementType]) -> String: """Convert the list to its string representation. + Parameters: + ElementType: Used to conditionally enable this function when + `ElementType` is `Writable`. + Returns: String representation of the list. """ @@ -315,11 +669,15 @@ struct LinkedList[ElementType: WritableCollectionElement]: self._write(writer, prefix="LinkedList(", suffix=")") return writer - fn write_to[W: Writer](self, mut writer: W): + fn write_to[ + W: Writer, ElementType: WritableCollectionElement + ](self: LinkedList[ElementType], mut writer: W): """Write the list to the given writer. Parameters: W: The type of writer to write the list to. + ElementType: Used to conditionally enable this function when + `ElementType` is `Writable`. Args: writer: The writer to write the list to. @@ -328,8 +686,14 @@ struct LinkedList[ElementType: WritableCollectionElement]: @no_inline fn _write[ - W: Writer - ](self, mut writer: W, *, prefix: String = "[", suffix: String = "]"): + W: Writer, ElementType: WritableCollectionElement + ]( + self: LinkedList[ElementType], + mut writer: W, + *, + prefix: String = "[", + suffix: String = "]", + ): if not self: return writer.write(prefix, suffix) @@ -338,6 +702,6 @@ struct LinkedList[ElementType: WritableCollectionElement]: for i in range(len(self)): if i: writer.write(", ") - writer.write(curr[]) + writer.write(curr[].value) curr = curr[].next writer.write(suffix) diff --git a/stdlib/test/collections/test_linked_list.mojo b/stdlib/test/collections/test_linked_list.mojo index 1375def212..42fd2a533c 100644 --- a/stdlib/test/collections/test_linked_list.mojo +++ b/stdlib/test/collections/test_linked_list.mojo @@ -12,8 +12,9 @@ # ===----------------------------------------------------------------------=== # # RUN: %mojo-no-debug %s -from collections import LinkedList -from testing import assert_equal +from collections import LinkedList, Optional +from testing import assert_equal, assert_raises, assert_true, assert_false +from test_utils import CopyCounter, MoveCounter def test_construction(): @@ -101,12 +102,487 @@ def test_setitem(): def test_str(): var l1 = LinkedList[Int](1, 2, 3) - assert_equal(String(l1), "[1, 2, 3]") + assert_equal(l1.__str__(), "[1, 2, 3]") def test_repr(): var l1 = LinkedList[Int](1, 2, 3) - assert_equal(repr(l1), "LinkedList(1, 2, 3)") + assert_equal(l1.__repr__(), "LinkedList(1, 2, 3)") + + +def test_pop_on_empty_list(): + with assert_raises(): + var ll = LinkedList[Int]() + _ = ll.pop() + + +def test_optional_pop_on_empty_linked_list(): + var ll = LinkedList[Int]() + var result = ll.pop_if_present() + assert_false(Bool(result)) + + +def test_list(): + var list = LinkedList[Int]() + + for i in range(5): + list.append(i) + + assert_equal(5, len(list)) + assert_equal(0, list[0]) + assert_equal(1, list[1]) + assert_equal(2, list[2]) + assert_equal(3, list[3]) + assert_equal(4, list[4]) + + assert_equal(0, list[-5]) + assert_equal(3, list[-2]) + assert_equal(4, list[-1]) + + list[2] = -2 + assert_equal(-2, list[2]) + + list[-5] = 5 + assert_equal(5, list[-5]) + list[-2] = 3 + assert_equal(3, list[-2]) + list[-1] = 7 + assert_equal(7, list[-1]) + + +def test_list_clear(): + var list = LinkedList[Int](1, 2, 3) + assert_equal(len(list), 3) + list.clear() + + assert_equal(len(list), 0) + + +def test_list_to_bool_conversion(): + assert_false(LinkedList[String]()) + assert_true(LinkedList[String]("a")) + assert_true(LinkedList[String]("", "a")) + assert_true(LinkedList[String]("")) + + +def test_list_pop(): + var list = LinkedList[Int]() + # Test pop with index + for i in range(6): + list.append(i) + + assert_equal(6, len(list)) + + # try popping from index 3 for 3 times + for i in range(3, 6): + assert_equal[Int](i, list.pop(3)) + + # list should have 3 elements now + assert_equal(3, len(list)) + assert_equal(0, list[0]) + assert_equal(1, list[1]) + assert_equal(2, list[2]) + + # Test pop with negative index + for i in range(0, 2): + var popped: Int = list.pop(-len(list)) + assert_equal(i, popped) + + # test default index as well + assert_equal(2, list.pop()) + list.append(2) + assert_equal(2, list.pop()) + + # list should be empty now + assert_equal(0, len(list)) + + +def test_list_variadic_constructor(): + var l = LinkedList[Int](2, 4, 6) + assert_equal(3, len(l)) + assert_equal(2, l[0]) + assert_equal(4, l[1]) + assert_equal(6, l[2]) + + l.append(8) + assert_equal(4, len(l)) + assert_equal(8, l[3]) + + # + # Test variadic construct copying behavior + # + + var l2 = LinkedList[CopyCounter]( + CopyCounter(), CopyCounter(), CopyCounter() + ) + + assert_equal(len(l2), 3) + assert_equal(l2[0].copy_count, 0) + assert_equal(l2[1].copy_count, 0) + assert_equal(l2[2].copy_count, 0) + + +def test_list_reverse(): + # + # Test reversing the list [] + # + + var vec = LinkedList[Int]() + + assert_equal(len(vec), 0) + + vec.reverse() + + assert_equal(len(vec), 0) + + # + # Test reversing the list [123] + # + + vec = LinkedList[Int]() + + vec.append(123) + + assert_equal(len(vec), 1) + assert_equal(vec[0], 123) + + vec.reverse() + + assert_equal(len(vec), 1) + assert_equal(vec[0], 123) + + # + # Test reversing the list ["one", "two", "three"] + # + + var vec2 = LinkedList[String]("one", "two", "three") + + assert_equal(len(vec2), 3) + assert_equal(vec2[0], "one") + assert_equal(vec2[1], "two") + assert_equal(vec2[2], "three") + + vec2.reverse() + + assert_equal(len(vec2), 3) + assert_equal(vec2[0], "three") + assert_equal(vec2[1], "two") + assert_equal(vec2[2], "one") + + # + # Test reversing the list [5, 10] + # + + vec = LinkedList[Int]() + vec.append(5) + vec.append(10) + + assert_equal(len(vec), 2) + assert_equal(vec[0], 5) + assert_equal(vec[1], 10) + + vec.reverse() + + assert_equal(len(vec), 2) + assert_equal(vec[0], 10) + assert_equal(vec[1], 5) + + +def test_list_insert(): + # + # Test the list [1, 2, 3] created with insert + # + + var v1 = LinkedList[Int]() + v1.insert(len(v1), 1) + v1.insert(len(v1), 3) + v1.insert(1, 2) + + assert_equal(len(v1), 3) + assert_equal(v1[0], 1) + assert_equal(v1[1], 2) + assert_equal(v1[2], 3) + + print(v1.__str__()) + + # + # Test the list [1, 2, 3, 4, 5] created with negative and positive index + # + + var v2 = LinkedList[Int]() + v2.insert(-1729, 2) + v2.insert(len(v2), 3) + v2.insert(len(v2), 5) + v2.insert(-1, 4) + v2.insert(-len(v2), 1) + print(v2.__str__()) + + assert_equal(len(v2), 5) + assert_equal(v2[0], 1) + assert_equal(v2[1], 2) + assert_equal(v2[2], 3) + assert_equal(v2[3], 4) + assert_equal(v2[4], 5) + + # + # Test the list [1, 2, 3, 4] created with negative index + # + + var v3 = LinkedList[Int]() + v3.insert(-11, 4) + v3.insert(-13, 3) + v3.insert(-17, 2) + v3.insert(-19, 1) + + assert_equal(len(v3), 4) + assert_equal(v3[0], 1) + assert_equal(v3[1], 2) + assert_equal(v3[2], 3) + assert_equal(v3[3], 4) + + # + # Test the list [1, 2, 3, 4, 5, 6, 7, 8] created with insert + # + + var v4 = LinkedList[Int]() + for i in range(4): + v4.insert(0, 4 - i) + v4.insert(len(v4), 4 + i + 1) + + for i in range(len(v4)): + assert_equal(v4[i], i + 1) + + +def test_list_extend_non_trivial(): + # Tests three things: + # - extend() for non-plain-old-data types + # - extend() with mixed-length self and other lists + # - extend() using optimal number of __moveinit__() calls + + # Preallocate with enough capacity to avoid reallocation making the + # move count checks below flaky. + var v1 = LinkedList[MoveCounter[String]]() + v1.append(MoveCounter[String]("Hello")) + v1.append(MoveCounter[String]("World")) + + var v2 = LinkedList[MoveCounter[String]]() + v2.append(MoveCounter[String]("Foo")) + v2.append(MoveCounter[String]("Bar")) + v2.append(MoveCounter[String]("Baz")) + + v1.extend(v2^) + + assert_equal(len(v1), 5) + assert_equal(v1[0].value, "Hello") + assert_equal(v1[1].value, "World") + assert_equal(v1[2].value, "Foo") + assert_equal(v1[3].value, "Bar") + assert_equal(v1[4].value, "Baz") + + assert_equal(v1[0].move_count, 1) + assert_equal(v1[1].move_count, 1) + assert_equal(v1[2].move_count, 1) + assert_equal(v1[3].move_count, 1) + assert_equal(v1[4].move_count, 1) + + +def test_2d_dynamic_list(): + var list = LinkedList[LinkedList[Int]]() + + for i in range(2): + var v = LinkedList[Int]() + for j in range(3): + v.append(i + j) + list.append(v) + + assert_equal(0, list[0][0]) + assert_equal(1, list[0][1]) + assert_equal(2, list[0][2]) + assert_equal(1, list[1][0]) + assert_equal(2, list[1][1]) + assert_equal(3, list[1][2]) + + assert_equal(2, len(list)) + + assert_equal(3, len(list[0])) + + list[0].clear() + assert_equal(0, len(list[0])) + + list.clear() + assert_equal(0, len(list)) + + +def test_list_explicit_copy(): + var list = LinkedList[CopyCounter]() + list.append(CopyCounter()) + var list_copy = list.copy() + assert_equal(0, list[0].copy_count) + assert_equal(1, list_copy[0].copy_count) + + var l2 = LinkedList[Int]() + for i in range(10): + l2.append(i) + + var l2_copy = l2.copy() + assert_equal(len(l2), len(l2_copy)) + for i in range(len(l2)): + assert_equal(l2[i], l2_copy[i]) + + +@value +struct CopyCountedStruct(CollectionElement): + var counter: CopyCounter + var value: String + + fn __init__(out self, *, other: Self): + self.counter = other.counter.copy() + self.value = other.value.copy() + + @implicit + fn __init__(out self, value: String): + self.counter = CopyCounter() + self.value = value + + +def test_no_extra_copies_with_sugared_set_by_field(): + var list = LinkedList[LinkedList[CopyCountedStruct]]() + var child_list = LinkedList[CopyCountedStruct]() + child_list.append(CopyCountedStruct("Hello")) + child_list.append(CopyCountedStruct("World")) + + # No copies here. Constructing with LinkedList[CopyCountedStruct](CopyCountedStruct("Hello")) is a copy. + assert_equal(0, child_list[0].counter.copy_count) + assert_equal(0, child_list[1].counter.copy_count) + + list.append(child_list^) + + assert_equal(0, list[0][0].counter.copy_count) + assert_equal(0, list[0][1].counter.copy_count) + + # list[0][1] makes a copy for reasons I cannot determine + list.__getitem__(0).__getitem__(1).value = "Mojo" + + assert_equal(0, list[0][0].counter.copy_count) + assert_equal(0, list[0][1].counter.copy_count) + + assert_equal("Mojo", list[0][1].value) + + assert_equal(0, list[0][0].counter.copy_count) + assert_equal(0, list[0][1].counter.copy_count) + + +def test_list_boolable(): + assert_true(LinkedList[Int](1)) + assert_false(LinkedList[Int]()) + + +def test_list_count(): + var list = LinkedList[Int](1, 2, 3, 2, 5, 6, 7, 8, 9, 10) + assert_equal(1, list.count(1)) + assert_equal(2, list.count(2)) + assert_equal(0, list.count(4)) + + var list2 = LinkedList[Int]() + assert_equal(0, list2.count(1)) + + +def test_list_contains(): + var x = LinkedList[Int](1, 2, 3) + assert_false(0 in x) + assert_true(1 in x) + assert_false(4 in x) + + # TODO: implement LinkedList.__eq__ for Self[ComparableCollectionElement] + # var y = LinkedList[LinkedList[Int]]() + # y.append(LinkedList(1,2)) + # assert_equal(LinkedList(1,2) in y,True) + # assert_equal(LinkedList(0,1) in y,False) + + +def test_list_eq_ne(): + var l1 = LinkedList[Int](1, 2, 3) + var l2 = LinkedList[Int](1, 2, 3) + assert_true(l1 == l2) + assert_false(l1 != l2) + + var l3 = LinkedList[Int](1, 2, 3, 4) + assert_false(l1 == l3) + assert_true(l1 != l3) + + var l4 = LinkedList[Int]() + var l5 = LinkedList[Int]() + assert_true(l4 == l5) + assert_true(l1 != l4) + + var l6 = LinkedList[String]("a", "b", "c") + var l7 = LinkedList[String]("a", "b", "c") + var l8 = LinkedList[String]("a", "b") + assert_true(l6 == l7) + assert_false(l6 != l7) + assert_false(l6 == l8) + + +def test_indexing(): + var l = LinkedList[Int](1, 2, 3) + assert_equal(l[Int(1)], 2) + assert_equal(l[False], 1) + assert_equal(l[True], 2) + assert_equal(l[2], 3) + + +# ===-------------------------------------------------------------------===# +# LinkedList dtor tests +# ===-------------------------------------------------------------------===# +var g_dtor_count: Int = 0 + + +struct DtorCounter(CollectionElement, Writable): + # NOTE: payload is required because LinkedList does not support zero sized structs. + var payload: Int + + fn __init__(out self): + self.payload = 0 + + fn __init__(out self, *, other: Self): + self.payload = other.payload + + fn __copyinit__(out self, existing: Self, /): + self.payload = existing.payload + + fn __moveinit__(out self, owned existing: Self, /): + self.payload = existing.payload + existing.payload = 0 + + fn __del__(owned self): + g_dtor_count += 1 + + fn write_to[W: Writer](self, mut writer: W): + writer.write("DtorCounter(") + writer.write(String(g_dtor_count)) + writer.write(")") + + +def inner_test_list_dtor(): + # explicitly reset global counter + g_dtor_count = 0 + + var l = LinkedList[DtorCounter]() + assert_equal(g_dtor_count, 0) + + l.append(DtorCounter()) + assert_equal(g_dtor_count, 0) + + l^.__del__() + assert_equal(g_dtor_count, 1) + + +def test_list_dtor(): + # call another function to force the destruction of the list + inner_test_list_dtor() + + # verify we still only ran the destructor once + assert_equal(g_dtor_count, 1) def main(): @@ -120,3 +596,22 @@ def main(): test_setitem() test_str() test_repr() + test_pop_on_empty_list() + test_optional_pop_on_empty_linked_list() + test_list() + test_list_clear() + test_list_to_bool_conversion() + test_list_pop() + test_list_variadic_constructor() + test_list_reverse() + test_list_extend_non_trivial() + test_list_explicit_copy() + test_no_extra_copies_with_sugared_set_by_field() + test_2d_dynamic_list() + test_list_boolable() + test_list_count() + test_list_contains() + test_indexing() + test_list_dtor() + test_list_insert() + test_list_eq_ne() diff --git a/stdlib/test/test_utils/types.mojo b/stdlib/test/test_utils/types.mojo index 01870acebc..0610304a36 100644 --- a/stdlib/test/test_utils/types.mojo +++ b/stdlib/test/test_utils/types.mojo @@ -88,7 +88,7 @@ struct ImplicitCopyOnly(Copyable): # ===----------------------------------------------------------------------=== # -struct CopyCounter(CollectionElement, ExplicitlyCopyable): +struct CopyCounter(CollectionElement, ExplicitlyCopyable, Writable): """Counts the number of copies performed on a value.""" var copy_count: Int @@ -108,6 +108,11 @@ struct CopyCounter(CollectionElement, ExplicitlyCopyable): fn copy(self) -> Self: return self + fn write_to[W: Writer](self, mut writer: W): + writer.write("CopyCounter(") + writer.write(String(self.copy_count)) + writer.write(")") + # ===----------------------------------------------------------------------=== # # MoveCounter @@ -117,6 +122,7 @@ struct CopyCounter(CollectionElement, ExplicitlyCopyable): struct MoveCounter[T: CollectionElementNew]( CollectionElement, CollectionElementNew, + Writable, ): """Counts the number of moves performed on a value.""" @@ -155,6 +161,11 @@ struct MoveCounter[T: CollectionElementNew]( fn copy(self) -> Self: return self + fn write_to[W: Writer](self, mut writer: W): + writer.write("MoveCounter(") + writer.write(String(self.move_count)) + writer.write(")") + # ===----------------------------------------------------------------------=== # # ValueDestructorRecorder