Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Move StringRef find() implementation to Span #3548

Open
wants to merge 66 commits into
base: nightly
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
45f9bec
move StrRef find() implementation to Span
martinvuyk Sep 25, 2024
070ac79
fix details
martinvuyk Sep 25, 2024
9950729
fix details
martinvuyk Sep 25, 2024
5a1cf5d
fix details
martinvuyk Sep 25, 2024
c0dac09
fix details
martinvuyk Sep 25, 2024
de6e116
fix details
martinvuyk Sep 25, 2024
d03ded0
fix details
martinvuyk Sep 25, 2024
a4c2a57
Merge branch 'nightly' into move-strref-find-impl-to-span
martinvuyk Oct 1, 2024
77a333f
Merge branch 'nightly' into move-strref-find-impl-to-span
martinvuyk Oct 1, 2024
f9a2eed
Merge branch 'nightly' into move-strref-find-impl-to-span
martinvuyk Oct 1, 2024
5e8fb25
Merge remote-tracking branch 'upstream/nightly' into move-strref-find…
martinvuyk Oct 3, 2024
6ca281e
fix stringref find and add fixme for later
martinvuyk Oct 3, 2024
6535e8b
fix detail
martinvuyk Oct 3, 2024
f6f9144
fix detail
martinvuyk Oct 3, 2024
964f4d5
Merge branch 'nightly' into move-strref-find-impl-to-span
martinvuyk Oct 6, 2024
828490b
Merge branch 'nightly' into move-strref-find-impl-to-span
martinvuyk Oct 10, 2024
015e5d4
Merge remote-tracking branch 'upstream/nightly' into move-strref-find…
martinvuyk Oct 13, 2024
d84530e
fix detail
martinvuyk Oct 13, 2024
2f37b51
fix detail
martinvuyk Oct 13, 2024
f1cb5d1
fix detail
martinvuyk Oct 13, 2024
702cb00
Merge remote-tracking branch 'upstream/nightly' into move-strref-find…
martinvuyk Oct 16, 2024
5a105a2
remove fixme comment
martinvuyk Oct 16, 2024
e4b495e
Merge branch 'nightly' into move-strref-find-impl-to-span
martinvuyk Oct 17, 2024
650d63c
Merge remote-tracking branch 'upstream/nightly' into move-strref-find…
martinvuyk Oct 21, 2024
f9a7329
fix detail
martinvuyk Oct 21, 2024
c7638f4
Merge remote-tracking branch 'upstream/nightly' into move-strref-find…
martinvuyk Oct 29, 2024
c29ab9d
refactor memrchr and memrmem and add parameter to use memchr and memr…
martinvuyk Oct 29, 2024
21c2e7a
Merge branch 'nightly' into move-strref-find-impl-to-span
martinvuyk Oct 29, 2024
ef6a1ec
fix a bunch of stuff
martinvuyk Oct 29, 2024
9898d6c
fix detail
martinvuyk Oct 29, 2024
a57d446
fix details
martinvuyk Oct 29, 2024
5f69209
fix details
martinvuyk Oct 29, 2024
feee988
fix details
martinvuyk Oct 29, 2024
ec78800
fix details
martinvuyk Oct 29, 2024
1d27970
fix details
martinvuyk Oct 29, 2024
de85d7e
Merge remote-tracking branch 'upstream/nightly' into move-strref-find…
martinvuyk Oct 31, 2024
d311761
fix bugs in memrchr and memrmem
martinvuyk Oct 31, 2024
c5be607
Merge remote-tracking branch 'upstream/nightly' into move-strref-find…
martinvuyk Oct 31, 2024
3f89314
fix details
martinvuyk Oct 31, 2024
09d92b7
update changelog
martinvuyk Oct 31, 2024
646bb90
Merge remote-tracking branch 'upstream/nightly' into move-strref-find…
martinvuyk Oct 31, 2024
162332a
Merge remote-tracking branch 'upstream/nightly' into move-strref-find…
martinvuyk Nov 4, 2024
22857e2
update to use pack_bits
martinvuyk Nov 4, 2024
f9e8fda
add overloads
martinvuyk Nov 4, 2024
8cd3b56
fix detail
martinvuyk Nov 4, 2024
a619944
fix detail
martinvuyk Nov 4, 2024
7213788
use var again
martinvuyk Nov 4, 2024
cfe5a5f
use var again
martinvuyk Nov 4, 2024
c3c0cd3
fix detail
martinvuyk Nov 4, 2024
8b88b38
fix detail
martinvuyk Nov 4, 2024
3fd3529
fix detail
martinvuyk Nov 4, 2024
e323392
fix detail
martinvuyk Nov 4, 2024
dca7728
Merge branch 'nightly' into move-strref-find-impl-to-span
martinvuyk Nov 6, 2024
8c29f12
Merge branch 'nightly' into move-strref-find-impl-to-span
martinvuyk Nov 12, 2024
aa791cd
fix unsafe ptr constructor
martinvuyk Nov 12, 2024
de191aa
Merge branch 'nightly' into move-strref-find-impl-to-span
martinvuyk Nov 14, 2024
6926293
Merge remote-tracking branch 'upstream/nightly' into move-strref-find…
martinvuyk Nov 21, 2024
68c29f5
Merge remote-tracking branch 'upstream/nightly' into move-strref-find…
martinvuyk Dec 3, 2024
ae73214
mojo format
martinvuyk Dec 3, 2024
74617ea
Merge remote-tracking branch 'upstream/nightly' into move-strref-find…
martinvuyk Dec 9, 2024
ef863cf
fix after merge with nightly
martinvuyk Dec 9, 2024
7b43af6
Merge remote-tracking branch 'upstream/nightly' into move-strref-find…
martinvuyk Dec 19, 2024
52a0c90
try fix markdownlint
martinvuyk Dec 19, 2024
1a8f348
Merge remote-tracking branch 'upstream/nightly' into move-strref-find…
martinvuyk Dec 26, 2024
183a607
fix after merge
martinvuyk Dec 26, 2024
53557d2
fix after merge
martinvuyk Dec 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions stdlib/src/builtin/string_literal.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ struct StringLiteral(
Returns:
The offset of `substr` relative to the beginning of the string.
"""
return StringRef(self).find(substr, start=start)
return self.as_string_slice().find(substr.as_string_slice(), start)

fn rfind(self, substr: StringLiteral, start: Int = 0) -> Int:
"""Finds the offset of the last occurrence of `substr` starting at
Expand All @@ -404,7 +404,7 @@ struct StringLiteral(
Returns:
The offset of `substr` relative to the beginning of the string.
"""
return StringRef(self).rfind(substr, start=start)
return self.as_string_slice().rfind(substr.as_string_slice(), start)

fn replace(self, old: StringLiteral, new: StringLiteral) -> StringLiteral:
"""Return a copy of the string with all occurrences of substring `old`
Expand Down
6 changes: 1 addition & 5 deletions stdlib/src/collections/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1524,7 +1524,6 @@ struct String(
Returns:
The offset of `substr` relative to the beginning of the string.
"""

return self.as_string_slice().find(substr.as_string_slice(), start)

fn rfind(self, substr: String, start: Int = 0) -> Int:
Expand All @@ -1538,10 +1537,7 @@ struct String(
Returns:
The offset of `substr` relative to the beginning of the string.
"""

return self._strref_dangerous().rfind(
substr._strref_dangerous(), start=start
)
return self.as_string_slice().rfind(substr.as_string_slice(), start)

fn isspace(self) -> Bool:
"""Determines whether every character in the given String is a
Expand Down
196 changes: 195 additions & 1 deletion stdlib/src/utils/span.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@ from utils import Span
"""

from collections import InlineArray
from memory import Reference, UnsafePointer
from memory import Reference, UnsafePointer, bitcast, memcmp
from sys.intrinsics import _type_is_eq
from builtin.builtin_list import _lit_mut_cast
from sys import simdwidthof
from bit import count_trailing_zeros
from builtin.dtype import _uint_type_of_width


@value
Expand Down Expand Up @@ -335,3 +338,194 @@ struct Span[
return Span[T, _lit_mut_cast[lifetime, False].result](
unsafe_ptr=self._data, len=self._len
)

fn find[
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question It seems a bit weird from an API design perspective to have this "find" function in span, can you help me understand why we'd want this? It feels more coupled to string algorithms, and Span isn't just for a view over string data.

Copy link
Contributor Author

@martinvuyk martinvuyk Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can be used if you want to find occurrences of a specific scalar value sequence. My main thought was that we'll lose the ability to do that once StringSlice.find() is fixed to use unicode codepoints and it introduces quite some overhead for algorithms that are faster using raw bytes. We could also make this private WDYT?

Also, once PR's #3577 DType.get_dtype() gets merged, List.index() can also delegate to this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JoeLoser gentle ping and also to add something, since it seems like we are moving toward unifying Python's bytes into Span[UInt8]in PR #3636, request #3634 and many other places. Python's bytes.find is actually a method, and many other things that we put in StringSlice and StringRef are actually operations from bytes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a great change for overall performance when manipulating slices. It opens a lot of possibilities.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JoeLoser per discussion on discord, Span[T].find is intended to be memchr or memmem as appropriate, but that the subspan of bytes must be aligned to alignof[T](). We can probably expand it to any T: Trivial once we have a Trivial trait.

The proper string version likely needs be specialized for UTF-8 comparisons. @martinvuyk and I have discussed an "AsciiString" which can more aggressively optimize for using byte-oriented intrinsics since UTF-8 is annoying to do with SIMD.

D: DType, //, from_left: Bool = True
](self: Span[Scalar[D]], subseq: Span[Scalar[D]], start: Int = 0) -> Int:
"""Finds the offset of the first occurrence of `subseq` starting at
`start`. If not found, returns -1.

Parameters:
D: The `DType` of the Scalar.
from_left: Whether to search the first occurrence from the left.

Args:
subseq: The sub sequence to find.
start: The offset from which to find.

Returns:
The offset of `subseq` relative to the beginning of the `Span`.
"""
var _len = len(self)

if not subseq:

@parameter
if from_left:
return 0
else:
return _len

if _len < len(subseq) + start:
return -1

var start_norm = max(_len + start, 0) if start < 0 else min(_len, start)
var haystack = __type_of(self)(
unsafe_ptr=self.unsafe_ptr() + start_norm, len=_len - start_norm
)
var loc: UnsafePointer[Scalar[D]]

@parameter
if from_left:
loc = _memmem(haystack, subseq)
else:
loc = _memrmem(haystack, subseq)

return int(loc) - int(self.unsafe_ptr()) if loc else -1

@always_inline
fn rfind[
D: DType, //
](self: Span[Scalar[D]], subseq: Span[Scalar[D]], start: Int = 0) -> Int:
"""Finds the offset of the last occurrence of `subseq` starting at
`start`. If not found, returns -1.

Parameters:
D: The `DType` of the Scalar.

Args:
subseq: The sub sequence to find.
start: The offset from which to find.

Returns:
The offset of `subseq` relative to the beginning of the `Span`.
"""
return self.find[from_left=False](subseq, start)


# ===----------------------------------------------------------------------===#
# Utilities
# ===----------------------------------------------------------------------===#


@always_inline
fn _align_down(value: Int, alignment: Int) -> Int:
return value._positive_div(alignment) * alignment


@always_inline
fn _memchr[
type: DType
](
source: UnsafePointer[Scalar[type]], char: Scalar[type], len: Int
) -> UnsafePointer[Scalar[type]]:
if not len:
return UnsafePointer[Scalar[type]]()
alias bool_mask_width = simdwidthof[DType.bool]()
var first_needle = SIMD[type, bool_mask_width](char)
var vectorized_end = _align_down(len, bool_mask_width)

for i in range(0, vectorized_end, bool_mask_width):
var bool_mask = source.load[width=bool_mask_width](i) == first_needle
var mask = bitcast[_uint_type_of_width[bool_mask_width]()](bool_mask)
if mask:
return source + int(i + count_trailing_zeros(mask))

for i in range(vectorized_end, len):
if source[i] == char:
return source + i
return UnsafePointer[Scalar[type]]()


@always_inline
fn _memmem[
type: DType
](
haystack_span: Span[Scalar[type]], needle_span: Span[Scalar[type]]
) -> UnsafePointer[Scalar[type]]:
var haystack = haystack_span.unsafe_ptr()
var haystack_len = len(haystack_span)
var needle = needle_span.unsafe_ptr()
var needle_len = len(needle_span)
if not needle_len:
return haystack
if needle_len > haystack_len:
return UnsafePointer[Scalar[type]]()
if needle_len == 1:
return _memchr[type](haystack, needle[0], haystack_len)

alias bool_mask_width = simdwidthof[DType.bool]()
var vectorized_end = _align_down(
haystack_len - needle_len + 1, bool_mask_width
)

var first_needle = SIMD[type, bool_mask_width](needle[0])
var last_needle = SIMD[type, bool_mask_width](needle[needle_len - 1])

for i in range(0, vectorized_end, bool_mask_width):
var first_block = haystack.load[width=bool_mask_width](i)
var last_block = haystack.load[width=bool_mask_width](
i + needle_len - 1
)

var eq_first = first_needle == first_block
var eq_last = last_needle == last_block

var bool_mask = eq_first & eq_last
var mask = bitcast[_uint_type_of_width[bool_mask_width]()](bool_mask)

while mask:
var offset = int(i + count_trailing_zeros(mask))
if memcmp(haystack + offset + 1, needle + 1, needle_len - 1) == 0:
return haystack + offset
mask = mask & (mask - 1)

# remaining partial block compare using byte-by-byte
#
for i in range(vectorized_end, haystack_len - needle_len + 1):
if haystack[i] != needle[0]:
continue

if memcmp(haystack + i + 1, needle + 1, needle_len - 1) == 0:
return haystack + i
_ = haystack_span, needle_span
return UnsafePointer[Scalar[type]]()


@always_inline
fn _memrchr[
type: DType
](
source: UnsafePointer[Scalar[type]], char: Scalar[type], len: Int
) -> UnsafePointer[Scalar[type]]:
if not len:
return UnsafePointer[Scalar[type]]()
for i in reversed(range(len)):
if source[i] == char:
return source + i
return UnsafePointer[Scalar[type]]()


@always_inline
fn _memrmem[
type: DType
](
haystack_span: Span[Scalar[type]], needle_span: Span[Scalar[type]]
) -> UnsafePointer[Scalar[type]]:
var haystack = haystack_span.unsafe_ptr()
var haystack_len = len(haystack_span)
var needle = needle_span.unsafe_ptr()
var needle_len = len(needle_span)
if not needle_len:
return haystack
if needle_len > haystack_len:
return UnsafePointer[Scalar[type]]()
if needle_len == 1:
return _memrchr[type](haystack, needle[0], haystack_len)
for i in reversed(range(haystack_len - needle_len + 1)):
if haystack[i] != needle[0]:
continue
if memcmp(haystack + i + 1, needle + 1, needle_len - 1) == 0:
return haystack + i
_ = haystack_span, needle_span
return UnsafePointer[Scalar[type]]()
31 changes: 13 additions & 18 deletions stdlib/src/utils/string_slice.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -568,27 +568,22 @@ struct StringSlice[
Returns:
The offset of `substr` relative to the beginning of the string.
"""
if not substr:
return 0
# FIXME(#3526): this should return unicode codepoint offsets
return self.as_bytes_span().find(substr.as_bytes_span(), start)

if self.byte_length() < substr.byte_length() + start:
return -1

# The substring to search within, offset from the beginning if `start`
# is positive, and offset from the end if `start` is negative.
var haystack_str = self._from_start(start)

var loc = stringref._memmem(
haystack_str.unsafe_ptr(),
haystack_str.byte_length(),
substr.unsafe_ptr(),
substr.byte_length(),
)
fn rfind(self, substr: StringSlice, start: Int = 0) -> Int:
"""Finds the offset of the last occurrence of `substr` starting at
`start`. If not found, returns -1.

if not loc:
return -1
Args:
substr: The substring to find.
start: The offset from which to find.

return int(loc) - int(self.unsafe_ptr())
Returns:
The offset of `substr` relative to the beginning of the string.
"""
# FIXME(#3526): this should return unicode codepoint offsets
return self.as_bytes_span().rfind(substr.as_bytes_span(), start)

fn isspace(self) -> Bool:
"""Determines whether every character in the given StringSlice is a
Expand Down
Loading