Skip to content

Commit

Permalink
[stdlib] Address Int vs UInt issues in normalize_index
Browse files Browse the repository at this point in the history
Signed-off-by: Yinon Burgansky <yinonburgansky@gmail.com>
  • Loading branch information
yinonburgansky committed Jan 21, 2025
1 parent 3abe545 commit 08417d9
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 70 deletions.
75 changes: 55 additions & 20 deletions stdlib/src/collections/_index_normalization.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -18,57 +18,92 @@ from sys.intrinsics import _type_is_eq

@always_inline
fn normalize_index[
    IdxType: Indexer, //, container_name: StringLiteral
](idx: IdxType, length: UInt) -> UInt:
    """Normalize the given index value to a valid index value for the given container length.

    If the provided value is negative, the `index + container_length` is returned.

    Unsigned index types (`UInt`, `UInt8`, `UInt16`, `UInt32`, `UInt64`) are
    only checked against the upper bound. Every other index type is converted
    to `Int` and checked against `[-length, length)`. All checks are performed
    with `debug_assert`.

    Parameters:
        IdxType: A type that can be used as an index.
        container_name: The name of the container. Used for the error message.

    Args:
        idx: The index value to normalize.
        length: The container length to normalize the index for.

    Returns:
        The normalized index value.
    """

    @parameter
    if (
        _type_is_eq[IdxType, UInt]()
        or _type_is_eq[IdxType, UInt8]()
        or _type_is_eq[IdxType, UInt16]()
        or _type_is_eq[IdxType, UInt32]()
        or _type_is_eq[IdxType, UInt64]()
    ):
        # Unsigned path: the comparison is done entirely in UInt, so lengths
        # above Int.MAX are handled without any signed conversion.
        var i = UInt(index(idx))
        debug_assert[assert_mode="safe", cpu_only=True](
            i < length,
            container_name,
            " index out of bounds: ",
            i,
            " should be less than ",
            length,
        )
        return i
    else:
        # Optimize for the common case:
        # Proper comparison between Int and UInt is slower and containers with
        # more than Int.MAX elements are rare.
        # Don't use "safe" since this is considered an overflow error.
        debug_assert(
            length <= UInt(Int.MAX),
            "Overflow Error: ",
            container_name,
            " length is greater than Int.MAX (",
            length,
            "). Consider indexing with the UInt type.",
        )
        var i = Int(idx)
        # TODO: Consider a way to construct the error message after the assert has failed
        # something like "Indexing into an empty container" if length == 0 else "..."
        debug_assert[assert_mode="safe", cpu_only=True](
            -Int(length) <= i < Int(length),
            container_name,
            " has length: ",
            length,
            " index out of bounds: ",
            i,
            " should be between ",
            -Int(length),
            " and ",
            # `length` is unsigned: `length - 1` would wrap around to UInt.MAX
            # for an empty container, so compute the upper bound as an Int.
            Int(length) - 1,
        )
        if i >= 0:
            return i
        # Negative index: offset from the end of the container.
        return i + length


@always_inline
fn normalize_index[
    IdxType: Indexer, //, container_name: StringLiteral
](idx: IdxType, length: Int) -> Int:
    """Normalize an index against a container length given as an `Int`.

    The length is converted to `UInt`, the `UInt`-length overload of
    `normalize_index` does the normalization, and its result is converted
    back to `Int`.

    Parameters:
        IdxType: A type that can be used as an index.
        container_name: The name of the container. Used for the error message.

    Args:
        idx: The index value to normalize.
        length: The container length to normalize the index for.

    Returns:
        The normalized index value.
    """
    var unsigned_length = UInt(length)
    var normalized = normalize_index[container_name](idx, unsigned_length)
    return Int(normalized)
25 changes: 4 additions & 21 deletions stdlib/src/collections/inline_array.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -235,15 +235,8 @@ struct InlineArray[
Returns:
A reference to the item at the given index.
"""

@parameter
if _type_is_eq[I, UInt]():
return self.unsafe_get(idx)
else:
var normalized_index = normalize_index["InlineArray"](
Int(idx), self
)
return self.unsafe_get(normalized_index)
var normalized_index = normalize_index["InlineArray"](idx, len(self))
return self.unsafe_get(normalized_index)

@always_inline
fn __getitem__[
Expand All @@ -259,18 +252,8 @@ struct InlineArray[
A reference to the item at the given index.
"""
constrained[-size <= Int(idx) < size, "Index must be within bounds."]()

@parameter
if _type_is_eq[I, UInt]():
return self.unsafe_get(idx)
else:
var normalized_idx = Int(idx)

@parameter
if Int(idx) < 0:
normalized_idx += size

return self.unsafe_get(normalized_idx)
alias normalized_index = normalize_index["InlineArray"](idx, size)
return self.unsafe_get(normalized_index)

# ===------------------------------------------------------------------=== #
# Trait implementations
Expand Down
20 changes: 15 additions & 5 deletions stdlib/src/collections/linked_list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -232,21 +232,25 @@ struct LinkedList[ElementType: WritableCollectionElement]:
curr = curr[].next
return new^

fn _get_node_ptr(ref self, index: Int) -> UnsafePointer[Node[ElementType]]:
fn _get_node_ptr[
I: Indexer
](ref self, index: I) -> UnsafePointer[Node[ElementType]]:
"""Get a pointer to the node at the specified index.
This method optimizes traversal by starting from either the head or tail
depending on which is closer to the target index.
Parameters:
I: A type that can be used as an index.
Args:
index: The index of the node to get.
Returns:
A pointer to the node at the specified index.
"""
var l = len(self)
var i = normalize_index[container_name="LinkedList"](index, self)
debug_assert(0 <= i < l, "index out of bounds")
var i = normalize_index["LinkedList"](index, l)
var mid = l // 2
if i <= mid:
var curr = self._head
Expand All @@ -259,9 +263,12 @@ struct LinkedList[ElementType: WritableCollectionElement]:
curr = curr[].prev
return curr

fn __getitem__(ref self, index: Int) -> ref [self] ElementType:
fn __getitem__[I: Indexer](ref self, index: I) -> ref [self] ElementType:
"""Get the element at the specified index.
Parameters:
I: A type that can be used as an index.
Args:
index: The index of the element to get.
Expand All @@ -271,9 +278,12 @@ struct LinkedList[ElementType: WritableCollectionElement]:
debug_assert(len(self) > 0, "unable to get item from empty list")
return self._get_node_ptr(index)[].value

fn __setitem__(mut self, index: Int, owned value: ElementType):
fn __setitem__[I: Indexer](mut self, index: I, owned value: ElementType):
"""Set the element at the specified index.
Parameters:
I: A type that can be used as an index.
Args:
index: The index of the element to set.
value: The new value to set.
Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/collections/string/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,7 @@ struct String(
A new string containing the character at the specified position.
"""
# TODO(#933): implement this for unicode when we support llvm intrinsic evaluation at compile time
var normalized_idx = normalize_index["String"](idx, self)
var normalized_idx = normalize_index["String"](idx, len(self))
var buf = Self._buffer_type(capacity=1)
buf.append(self._buffer[normalized_idx])
buf.append(0)
Expand Down
91 changes: 68 additions & 23 deletions stdlib/test/collections/test_index_normalization.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,78 @@ from testing import assert_equal


def test_out_of_bounds_message():
    """Trigger the `normalize_index` diagnostics matched by the CHECK lines.

    Each call below passes an index that is invalid for the given length; the
    `# CHECK:` comments are the expected diagnostic substrings, in order.
    """
    # Non-negative index equal to the length.
    # CHECK: index out of bounds
    _ = normalize_index[""](2, 2)
    # CHECK: index out of bounds
    _ = normalize_index[""](UInt(2), 2)
    # CHECK: index out of bounds
    _ = normalize_index[""](2, UInt(2))
    # CHECK: index out of bounds
    _ = normalize_index[""](UInt(2), UInt(2))
    # CHECK: index out of bounds
    _ = normalize_index[""](UInt8(2), 2)

    # Negative index below -length.
    # CHECK: index out of bounds
    _ = normalize_index[""](-3, 2)
    # CHECK: index out of bounds
    _ = normalize_index[""](-3, UInt(2))
    # CHECK: index out of bounds
    _ = normalize_index[""](Int8(-3), 2)

    # Zero length: every index is out of bounds.
    # CHECK: index out of bounds
    _ = normalize_index[""](2, 0)
    # CHECK: index out of bounds
    _ = normalize_index[""](UInt(2), 0)
    # CHECK: index out of bounds
    _ = normalize_index[""](2, UInt(0))
    # CHECK: index out of bounds
    _ = normalize_index[""](UInt(2), UInt(0))

    # Largest unsigned index.
    # CHECK: index out of bounds
    _ = normalize_index[""](UInt.MAX, 10)
    # CHECK: index out of bounds
    _ = normalize_index[""](UInt.MAX, UInt(10))
    # CHECK: index out of bounds
    _ = normalize_index[""](UInt.MAX, UInt.MAX)
    # CHECK: index out of bounds
    _ = normalize_index[""](UInt.MAX, UInt.MAX - 10)

    # Signed index with a length above Int.MAX. The length is built in UInt
    # arithmetic so the test does not depend on `Int.MAX + 1` wrapping around.
    # CHECK: Overflow Error
    _ = normalize_index[""](-1, UInt(Int.MAX) + 1)


def test_normalize_index():
    """Check `normalize_index` results for valid indices.

    Covers negative and non-negative indices across combinations of index
    type and length type.
    """
    # Integer-literal index, integer-literal length.
    assert_equal(normalize_index[""](-3, 3), 0)
    assert_equal(normalize_index[""](-2, 3), 1)
    assert_equal(normalize_index[""](-1, 3), 2)
    assert_equal(normalize_index[""](0, 3), 0)
    assert_equal(normalize_index[""](1, 3), 1)
    assert_equal(normalize_index[""](2, 3), 2)

    # Integer-literal index, UInt length.
    assert_equal(normalize_index[""](-3, UInt(3)), 0)
    assert_equal(normalize_index[""](-2, UInt(3)), 1)
    assert_equal(normalize_index[""](-1, UInt(3)), 2)
    assert_equal(normalize_index[""](0, UInt(3)), 0)
    assert_equal(normalize_index[""](1, UInt(3)), 1)
    assert_equal(normalize_index[""](2, UInt(3)), 2)

    # UInt index, UInt length.
    assert_equal(normalize_index[""](UInt(0), UInt(3)), 0)
    assert_equal(normalize_index[""](UInt(1), UInt(3)), 1)
    assert_equal(normalize_index[""](UInt(2), UInt(3)), 2)

    # Int8 index, integer-literal length.
    assert_equal(normalize_index[""](Int8(-3), 3), 0)
    assert_equal(normalize_index[""](Int8(-2), 3), 1)
    assert_equal(normalize_index[""](Int8(-1), 3), 2)
    assert_equal(normalize_index[""](Int8(0), 3), 0)
    assert_equal(normalize_index[""](Int8(1), 3), 1)
    assert_equal(normalize_index[""](Int8(2), 3), 2)

    # UInt8 index, integer-literal length.
    assert_equal(normalize_index[""](UInt8(0), 3), 0)
    assert_equal(normalize_index[""](UInt8(1), 3), 1)
    assert_equal(normalize_index[""](UInt8(2), 3), 2)

    # UInt index with a length above Int.MAX.
    assert_equal(normalize_index[""](UInt(1), UInt.MAX), 1)
    assert_equal(normalize_index[""](UInt.MAX - 5, UInt.MAX), UInt.MAX - 5)


def main():
Expand Down

0 comments on commit 08417d9

Please sign in to comment.