Skip to content

Commit

Permalink
[stdlib] Use Indexer in normalize_index
Browse files Browse the repository at this point in the history
Optimize normalize_index for UInt Indexer type.

Signed-off-by: Yinon Burgansky <yinonburgansky@gmail.com>
  • Loading branch information
yinonburgansky committed Jan 18, 2025
1 parent 0d00718 commit 3abe545
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 18 deletions.
52 changes: 35 additions & 17 deletions stdlib/src/collections/_index_normalization.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,19 @@
"""The utilities provided in this module help normalize the access
to data elements in arrays."""

from sys.intrinsics import _type_is_eq


@always_inline
fn normalize_index[
ContainerType: Sized, //, container_name: StringLiteral
](idx: Int, container: ContainerType) -> Int:
I: Indexer, ContainerType: Sized, //, container_name: StringLiteral
](idx: I, container: ContainerType) -> UInt:
"""Normalize the given index value to a valid index value for the given container length.
If the provided value is negative, the `index + container_length` is returned.
Parameters:
I: A type that can be used as an index.
ContainerType: The type of the container. Must have a `__len__` method.
container_name: The name of the container. Used for the error message.
Expand All @@ -39,18 +42,33 @@ fn normalize_index[
container_name,
" that has 0 elements",
)
debug_assert[assert_mode="safe", cpu_only=True](
-len(container) <= idx < len(container),
container_name,
" has length: ",
len(container),
" index out of bounds: ",
idx,
" should be between ",
-len(container),
" and ",
len(container) - 1,
)
if idx >= 0:
return idx
return idx + len(container)

@parameter
if _type_is_eq[I, UInt]():
var i = rebind[UInt](idx)
debug_assert[assert_mode="safe", cpu_only=True](
i < len(container),
container_name,
" index out of bounds: ",
i,
" should be less than ",
len(container),
)
return i
else:
var i = Int(idx)
debug_assert[assert_mode="safe", cpu_only=True](
-len(container) <= i < len(container),
container_name,
" has length: ",
len(container),
" index out of bounds: ",
i,
" should be between ",
-len(container),
" and ",
len(container) - 1,
)
if i >= 0:
return i
return i + len(container)
2 changes: 1 addition & 1 deletion stdlib/src/collections/string/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,7 @@ struct String(
A new string containing the character at the specified position.
"""
# TODO(#933): implement this for unicode when we support llvm intrinsic evaluation at compile time
var normalized_idx = normalize_index["String"](index(idx), self)
var normalized_idx = normalize_index["String"](idx, self)
var buf = Self._buffer_type(capacity=1)
buf.append(self._buffer[normalized_idx])
buf.append(0)
Expand Down
6 changes: 6 additions & 0 deletions stdlib/test/collections/test_index_normalization.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,16 @@ def test_out_of_bounds_message():
l = List[Int](1, 2)
# CHECK: index out of bounds: 2
_ = normalize_index["List"](2, l)
# CHECK: index out of bounds: 2
_ = normalize_index["List"](UInt(2), l)
# CHECK: index out of bounds: -3
_ = normalize_index["List"](-3, l)

l2 = List[Int]()
# CHECK: indexing into a List that has 0 elements
_ = normalize_index["List"](2, l2)
# CHECK: indexing into a List that has 0 elements
_ = normalize_index["List"](UInt(2), l2)


def test_normalize_index():
Expand All @@ -39,6 +43,8 @@ def test_normalize_index():
assert_equal(normalize_index[""](1, container), 1)
assert_equal(normalize_index[""](2, container), 2)
assert_equal(normalize_index[""](3, container), 3)
assert_equal(normalize_index[""](UInt(0), container), 0)
assert_equal(normalize_index[""](UInt(3), container), 3)


def main():
Expand Down

0 comments on commit 3abe545

Please sign in to comment.