Skip to content

Commit

Permalink
[External] [stdlib] Add `__init__(*, from_bits)` and `to_bits()` to `SIMD` (#50172)
Browse files Browse the repository at this point in the history

[External] [stdlib] Add `__init__(*, from_bits)` and `to_bits()` to
`SIMD`

Co-authored-by: soraros <soraros@users.noreply.github.com>
Closes #3680
MODULAR_ORIG_COMMIT_REV_ID: d36f1f1fd14105f89cff09fb77c6bcf5be662a03
  • Loading branch information
soraros authored and modularbot committed Nov 1, 2024
1 parent 4adf2b2 commit b99d031
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 59 deletions.
72 changes: 42 additions & 30 deletions stdlib/src/builtin/simd.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ from sys import sizeof, alignof
from .dtype import (
_get_dtype_printf_format,
_integral_type_of,
_unsigned_integral_type_of,
_scientific_notation_digits,
)
from .io import _printf, _snprintf_scalar
Expand Down Expand Up @@ -492,6 +493,20 @@ struct SIMD[type: DType, size: Int](
)
)

fn __init__[
int_type: DType, //
](inout self, *, from_bits: SIMD[int_type, size]):
"""Initializes the SIMD vector from the bits of an integral SIMD vector.

The bits of `from_bits` are reinterpreted as `SIMD[type, size]` via
`bitcast`; the input vector has the same `size` as the result.

Constraints:
The `int_type` must be an integral type.

Parameters:
int_type: The integral type of the input SIMD vector.

Args:
from_bits: The SIMD vector to copy the bits from.
"""
# Reject non-integral sources at compile time.
constrained[int_type.is_integral(), "the SIMD type must be integral"]()
# NOTE(review): presumably `bitcast` requires `int_type` and `type` to have
# the same bit width; no width check is done here -- confirm.
self = bitcast[type, size](from_bits)

# ===-------------------------------------------------------------------===#
# Operator dunders
# ===-------------------------------------------------------------------===#
Expand Down Expand Up @@ -776,9 +791,7 @@ struct SIMD[type: DType, size: Int](
# As a workaround, we roll our own implementation
@parameter
if has_neon() and type is DType.bfloat16:
var int_self = bitcast[_integral_type_of[type](), size](self)
var int_rhs = bitcast[_integral_type_of[type](), size](rhs)
return int_self == int_rhs
return self.to_bits() == rhs.to_bits()
else:
return __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop<cmp_pred eq>`](
self.value, rhs.value
Expand All @@ -801,9 +814,7 @@ struct SIMD[type: DType, size: Int](
# As a workaround, we roll our own implementation.
@parameter
if has_neon() and type is DType.bfloat16:
var int_self = bitcast[_integral_type_of[type](), size](self)
var int_rhs = bitcast[_integral_type_of[type](), size](rhs)
return int_self != int_rhs
return self.to_bits() != rhs.to_bits()
else:
return __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop<cmp_pred ne>`](
self.value, rhs.value
Expand Down Expand Up @@ -1514,9 +1525,9 @@ struct SIMD[type: DType, size: Int](
self
)

alias integral_type = FPUtils[type].integral_type
var m = self._float_to_bits[integral_type]()
return (m & (FPUtils[type].sign_mask() - 1))._bits_to_float[type]()
# FIXME: This should be an alias
var mask = FPUtils[type].exponent_mantissa_mask()
return Self(from_bits=self.to_bits() & mask)
else:
return (self < 0).select(-self, self)

Expand Down Expand Up @@ -1726,32 +1737,31 @@ struct SIMD[type: DType, size: Int](
if size > 1:
writer.write("]")

# FIXME: `_integral_type_of` doesn't work with `DType.bool`.
@always_inline
fn _bits_to_float[dest_type: DType](self) -> SIMD[dest_type, size]:
"""Bitcasts the integer value to a floating-point value.
Parameters:
dest_type: DType to bitcast the input SIMD vector to.
Returns:
A floating-point representation of the integer value.
"""
alias integral_type = FPUtils[type].integral_type
return bitcast[dest_type, size](self.cast[integral_type]())

@always_inline
fn _float_to_bits[dest_type: DType](self) -> SIMD[dest_type, size]:
"""Bitcasts the floating-point value to an integer value.
fn to_bits[
int_dtype: DType = _integral_type_of[type]()
](self) -> SIMD[int_dtype, size]:
"""Bitcasts the SIMD vector to an integer SIMD vector.
Parameters:
dest_type: DType to bitcast the input SIMD vector to.
int_dtype: The integer type to cast to.
Returns:
An integer representation of the floating-point value.
"""
alias integral_type = FPUtils[type].integral_type
var v = bitcast[integral_type, size](self)
return v.cast[dest_type]()
constrained[
int_dtype.is_integral(), "the target type must be integral"
]()
constrained[
bitwidthof[int_dtype]() >= bitwidthof[type](),
(
"the target integer type must be at least as wide as the source"
" type"
),
]()

return bitcast[_integral_type_of[type](), size](self).cast[int_dtype]()

fn _floor_ceil_trunc_impl[intrinsic: StringLiteral](self) -> Self:
constrained[
Expand Down Expand Up @@ -3221,14 +3231,16 @@ fn _floor(x: SIMD) -> __type_of(x):
alias bitwidth = bitwidthof[x.type]()
alias exponent_width = FPUtils[x.type].exponent_width()
alias mantissa_width = FPUtils[x.type].mantissa_width()
# FIXME: GH issue #3613
# alias mask = FPUtils[x.type].exponent_mask()
alias mask = (1 << exponent_width) - 1
alias bias = FPUtils[x.type].exponent_bias()
alias shift_factor = bitwidth - exponent_width - 1

var bits = bitcast[integral_type, x.size](x)
bits = x.to_bits()
var e = ((bits >> mantissa_width) & mask) - bias
bits = (e < shift_factor).select(
bits & ~((1 << (shift_factor - e)) - 1),
bits,
)
return bitcast[x.type, x.size](bits)
return __type_of(x)(from_bits=bits)
2 changes: 1 addition & 1 deletion stdlib/src/hashlib/_ahash.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ struct AHasher[key: U256](_Hasher):

@parameter
if new_data.type.is_floating_point():
v64 = new_data._float_to_bits[DType.uint64]()
v64 = new_data.to_bits().cast[DType.uint64]()
else:
v64 = new_data.cast[DType.uint64]()

Expand Down
42 changes: 17 additions & 25 deletions stdlib/src/math/math.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -418,11 +418,9 @@ fn exp2[
if type not in (DType.float32, DType.float64):
return exp2(x.cast[DType.float32]()).cast[type]()

alias integral_type = FPUtils[type].integral_type

var xc = x.clamp(-126, 126)

var m = xc.cast[integral_type]()
var m = xc.cast[__type_of(x.to_bits()).type]()

xc -= m.cast[type]()

Expand All @@ -436,11 +434,9 @@ fn exp2[
1.33336498402e-3,
),
](xc)

return (
r._float_to_bits[integral_type]()
+ (m << FPUtils[type].mantissa_width())
)._bits_to_float[type]()
return __type_of(r)(
from_bits=(r.to_bits() + (m << FPUtils[type].mantissa_width()))
)


# ===----------------------------------------------------------------------=== #
Expand Down Expand Up @@ -504,11 +500,9 @@ fn _ldexp_impl[
return res

alias integral_type = FPUtils[type].integral_type
var m: SIMD[integral_type, simd_width] = (
exp.cast[integral_type]() + FPUtils[type].exponent_bias()
)
var m = exp.cast[integral_type]() + FPUtils[type].exponent_bias()

return x * (m << FPUtils[type].mantissa_width())._bits_to_float[type]()
return x * __type_of(x)(from_bits=m << FPUtils[type].mantissa_width())


@always_inline
Expand Down Expand Up @@ -629,8 +623,8 @@ fn exp[

@always_inline
fn _frexp_mask1[
simd_width: Int, type: DType, integral_type: DType
]() -> SIMD[integral_type, simd_width]:
simd_width: Int, type: DType
]() -> SIMD[_integral_type_of[type](), simd_width]:
@parameter
if type is DType.float16:
return 0x7C00
Expand All @@ -645,8 +639,8 @@ fn _frexp_mask1[

@always_inline
fn _frexp_mask2[
simd_width: Int, type: DType, integral_type: DType
]() -> SIMD[integral_type, simd_width]:
simd_width: Int, type: DType
]() -> SIMD[_integral_type_of[type](), simd_width]:
@parameter
if type is DType.float16:
return 0x3800
Expand Down Expand Up @@ -681,22 +675,20 @@ fn frexp[
"""
# Based on the implementation in boost/simd/arch/common/simd/function/ifrexp.hpp
constrained[type.is_floating_point(), "must be a floating point value"]()
alias integral_type = _integral_type_of[type]()
alias zero = SIMD[type, simd_width](0)
alias T = SIMD[type, simd_width]
alias zero = T(0)
alias max_exponent = FPUtils[type].max_exponent() - 2
alias mantissa_width = FPUtils[type].mantissa_width()
var mask1 = _frexp_mask1[simd_width, type, integral_type]()
var mask2 = _frexp_mask2[simd_width, type, integral_type]()
var x_int = x._float_to_bits[integral_type]()
var mask1 = _frexp_mask1[simd_width, type]()
var mask2 = _frexp_mask2[simd_width, type]()
var x_int = x.to_bits()
var selector = x != zero
var exp = selector.select(
(((mask1 & x_int) >> mantissa_width) - max_exponent).cast[type](),
zero,
)
var frac = selector.select(
((x_int & ~mask1) | mask2)._bits_to_float[type](), zero
)
return StaticTuple[SIMD[type, simd_width], 2](frac, exp)
var frac = selector.select(T(from_bits=x_int & ~mask1 | mask2), zero)
return StaticTuple[size=2](frac, exp)


# ===----------------------------------------------------------------------=== #
Expand Down
1 change: 0 additions & 1 deletion stdlib/src/memory/memory.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ from sys import (
_libc as libc,
)
from collections import Optional
from builtin.dtype import _integral_type_of
from memory.pointer import AddressSpace, _GPUAddressSpace

# ===----------------------------------------------------------------------=== #
Expand Down
5 changes: 3 additions & 2 deletions stdlib/src/utils/numerics.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ struct FPUtils[
Returns:
The sign mask.
"""
# convert to `Int` first to bypass overflow check
return 1 << int(Self.exponent_width() + Self.mantissa_width())

@staticmethod
Expand All @@ -171,12 +172,12 @@ struct FPUtils[
fn exponent_mantissa_mask() -> Int:
"""Returns the exponent and mantissa mask of a floating point type.
It is computed by `exponent_mask + mantissa_mask`.
It is computed by `exponent_mask | mantissa_mask`.
Returns:
The exponent and mantissa mask.
"""
return Self.exponent_mask() + Self.mantissa_mask()
return Self.exponent_mask() | Self.mantissa_mask()

@staticmethod
@always_inline
Expand Down

0 comments on commit b99d031

Please sign in to comment.