Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Refactor memset() to be generic #3577

Open
wants to merge 28 commits into
base: nightly
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
3bcfe19
add string add with stringslice and refactor list resize to use new g…
martinvuyk Sep 30, 2024
95c3c57
fix issues
martinvuyk Sep 30, 2024
5f92a40
fix memset test
martinvuyk Sep 30, 2024
f5227f6
improve memset tests
martinvuyk Sep 30, 2024
070d4c8
Merge remote-tracking branch 'upstream/nightly' into add-string-add-s…
martinvuyk Sep 30, 2024
1451f4f
add fixme comment
martinvuyk Sep 30, 2024
38af4e7
add fixme comment
martinvuyk Sep 30, 2024
b6f58b2
add reference to feature request #3581
martinvuyk Oct 1, 2024
7b2e98f
Merge remote-tracking branch 'upstream/nightly' into add-string-add-s…
martinvuyk Oct 1, 2024
80b0a88
Merge branch 'nightly' into add-string-add-stringslice
martinvuyk Oct 1, 2024
e1aac32
remove usage of resize
martinvuyk Oct 2, 2024
47b2365
Merge branch 'add-string-add-stringslice' of github.com:martinvuyk/mo…
martinvuyk Oct 2, 2024
e8e5b4a
remove add string add stringslice from PR
martinvuyk Oct 2, 2024
1352c55
fix detail
martinvuyk Oct 2, 2024
1f4038a
Merge branch 'nightly' into add-string-add-stringslice
martinvuyk Oct 2, 2024
3b4a137
Merge branch 'nightly' into add-string-add-stringslice
martinvuyk Oct 6, 2024
34a6ad0
Merge remote-tracking branch 'upstream/nightly' into add-string-add-s…
martinvuyk Oct 21, 2024
3eca09a
Merge branch 'nightly' into add-string-add-stringslice
martinvuyk Nov 6, 2024
bce3b97
Merge branch 'nightly' into add-string-add-stringslice
martinvuyk Nov 12, 2024
c806549
Merge branch 'nightly' into add-string-add-stringslice
martinvuyk Nov 22, 2024
8ad72c7
fix detail
martinvuyk Nov 22, 2024
da475d2
start removing DType.get_dtype()
martinvuyk Nov 25, 2024
6bb95b8
Merge remote-tracking branch 'upstream/nightly' into add-string-add-s…
martinvuyk Nov 25, 2024
b6ed84d
remove DType.get_dtype() and undo changes to List
martinvuyk Nov 25, 2024
ba5d396
fix detail
martinvuyk Nov 25, 2024
c173093
fix detail
martinvuyk Nov 25, 2024
0d355fb
Merge remote-tracking branch 'upstream/nightly' into add-string-add-s…
martinvuyk Dec 3, 2024
b4f3552
mojo format
martinvuyk Dec 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 51 additions & 15 deletions stdlib/src/memory/memory.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -266,42 +266,78 @@ fn memcpy[

@always_inline("nodebug")
fn _memset_impl[
address_space: AddressSpace
D: DType, address_space: AddressSpace
](
ptr: UnsafePointer[Byte, address_space=address_space],
value: Byte,
ptr: UnsafePointer[Scalar[D], address_space=address_space],
value: Scalar[D],
count: Int,
):
alias simd_width = simdwidthof[Byte]()
alias simd_width = simdwidthof[Scalar[D]]()
var vector_end = _align_down(count, simd_width)

for i in range(0, vector_end, simd_width):
ptr.store(i, SIMD[DType.uint8, simd_width](value))
ptr.store(i, SIMD[D, simd_width](value))

for i in range(vector_end, count):
ptr.store(i, value)


@always_inline
fn memset[
type: AnyType, address_space: AddressSpace
](
ptr: UnsafePointer[type, address_space=address_space],
value: Byte,
count: Int,
):
D: DType
](ptr: UnsafePointer[Scalar[D]], value: Scalar[D], count: Int):
"""Fills memory with the given value.

Parameters:
type: The element dtype.
address_space: The address space of the pointer.
D: The element dtype.

Args:
ptr: UnsafePointer to the beginning of the memory block to fill.
value: The value to fill with.
count: Number of elements to fill (in elements, not bytes).
"""
_memset_impl(ptr, value, count)


@always_inline
fn memset[D: DType](ptr: UnsafePointer[Scalar[D]], value: Int, count: Int):
"""Fills memory with the given value.

Parameters:
D: The element dtype.

Args:
ptr: UnsafePointer to the beginning of the memory block to fill.
value: The value to fill with.
count: Number of elements to fill (in elements, not bytes).
"""
_memset_impl(ptr.bitcast[Byte](), value, count * sizeof[type]())
_memset_impl(ptr, Scalar[D](value), count)


# FIXME(#3581): this should only be for trivial types, but constraining it with
# AnyTrivialRegType would make building generics on top of this harder
@always_inline
fn memset[
T: AnyType, D: DType
](ptr: UnsafePointer[T], value: Scalar[D], count: Int):
"""Fills memory with the given value.

Parameters:
T: The element type.
D: The value's dtype.

Args:
ptr: UnsafePointer to the beginning of the memory block to fill.
value: The value to fill with.
count: Number of elements to fill (in elements, not bytes).
"""

alias size = sizeof[Scalar[D]]()
constrained[
sizeof[T]() == size,
"value to fill must be the same bitwidth as the type",
]()
_memset_impl(ptr.bitcast[Scalar[D]](), value, count * size)


# ===----------------------------------------------------------------------===#
Expand All @@ -323,7 +359,7 @@ fn memset_zero[
ptr: UnsafePointer to the beginning of the memory block to fill.
count: Number of elements to fill (in elements, not bytes).
"""
memset(ptr, 0, count)
_memset_impl(ptr.bitcast[UInt8](), UInt8(0), count * sizeof[type]())


@always_inline
Expand Down
53 changes: 26 additions & 27 deletions stdlib/test/memory/test_memory.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ from memory import (
memcpy,
memset,
memset_zero,
stack_allocation,
)
from testing import (
assert_almost_equal,
Expand Down Expand Up @@ -275,40 +276,38 @@ def test_memcmp_extensive():

def test_memset():
var pair = Pair(1, 2)

var ptr = UnsafePointer.address_of(pair)
memset_zero(ptr, 1)

memset(ptr.bitcast[UInt64](), 0, 1)
assert_equal(pair.lo, 0)
assert_equal(pair.hi, 0)
assert_equal(pair.hi, 2)

pair.lo = 1
pair.hi = 2
pair = Pair(1, 2)
memset_zero(ptr, 1)

assert_equal(pair.lo, 0)
assert_equal(pair.hi, 0)

var buf0 = UnsafePointer[Int32].alloc(2)
memset(buf0, 1, 2)
assert_equal(buf0.load(0), 16843009)
memset(buf0, -1, 2)
assert_equal(buf0.load(0), -1)
buf0.free()

var buf1 = UnsafePointer[Int8].alloc(2)
memset(buf1, 5, 2)
assert_equal(buf1.load(0), 5)
buf1.free()

var buf3 = UnsafePointer[Int32].alloc(2)
memset(buf3, 1, 2)
memset_zero[count=2](buf3)
assert_equal(buf3.load(0), 0)
assert_equal(buf3.load(1), 0)
buf3.free()

_ = pair
fn test_dtype[D: DType]() raises:
var buf1 = stack_allocation[2, Scalar[D]]()
memset(buf1, 5, 2)
assert_equal(buf1.load(0), 5)
assert_equal(buf1.load(1), 5)
var buf2 = UnsafePointer[Scalar[D]].alloc(2)
memset(buf2, 5, 2)
assert_equal(buf2.load(0), 5)
assert_equal(buf2.load(1), 5)
buf2.free()

test_dtype[DType.uint8]()
test_dtype[DType.uint16]()
test_dtype[DType.uint32]()
test_dtype[DType.uint64]()
test_dtype[DType.int8]()
test_dtype[DType.int16]()
test_dtype[DType.int32]()
test_dtype[DType.int64]()
test_dtype[DType.float16]()
test_dtype[DType.float32]()
test_dtype[DType.float64]()


def test_pointer_string():
Expand Down
3 changes: 2 additions & 1 deletion stdlib/test/python/my_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def __init__(self, bar):

class AbstractPerson(ABC):
@abstractmethod
def method(self): ...
def method(self):
...


def my_function(name):
Expand Down
Loading