diff --git a/docs_src/tutorials/intro_to_parametrics.md b/docs_src/tutorials/intro_to_parametrics.md index 9a57d51ac5..b458f21f07 100644 --- a/docs_src/tutorials/intro_to_parametrics.md +++ b/docs_src/tutorials/intro_to_parametrics.md @@ -2,7 +2,8 @@ This tutorial demonstrates how types and functions can be parameterized to enable them to work on data of different formats and layouts, e.g., for a -function `foo` to work on both u16 and u32 data types, and anywhere in between. +function `foo` to work on both `u16` and `u32` data types, and anywhere in +between. It's recommended that you're familiar with the concepts in the previous tutorial, @@ -11,7 +12,7 @@ before following this tutorial. ## Simple parametrics -Consider the simple example of the `umax` function +Consider the simple example of a `umax` function -- similar to the `max` function [in the DSLX standard library](https://github.com/google/xls/tree/main/xls/dslx/stdlib/std.x): ```dslx @@ -40,10 +41,12 @@ infer them: Explicit specification: ```dslx -import std; +fn umax(x: uN[N], y: uN[N]) -> uN[N] { + if x > y { x } else { y } +} fn foo(a: u32, b: u16) -> u64 { - std::umax(a as u64, b as u64) + umax(a as u64, b as u64) } ``` @@ -53,10 +56,12 @@ are. Parametric inference: ```dslx -import std; +fn umax(x: uN[N], y: uN[N]) -> uN[N] { + if x > y { x } else { y } +} fn foo(a: u32, b: u16) -> u64 { - std::umax(a as u64, b as u64) + umax(a as u64, b as u64) } ``` diff --git a/xls/dslx/ir_convert/function_converter.cc b/xls/dslx/ir_convert/function_converter.cc index 09a7c8bb6c..13a28fcd4c 100644 --- a/xls/dslx/ir_convert/function_converter.cc +++ b/xls/dslx/ir_convert/function_converter.cc @@ -960,12 +960,15 @@ absl::Status FunctionConverter::HandleBuiltinCheckedCast( int64_t old_bit_count, std::get(input_bit_count_ctd.value()).GetBitValueViaSign()); - if (dynamic_cast(output_type.get()) != nullptr || - dynamic_cast(input_type.get()) != nullptr) { + std::optional output_bits_like = + GetBitsLike(*output_type); + std::optional input_bits_like = GetBitsLike(*input_type); + + if (!output_bits_like.has_value() || !input_bits_like.has_value()) { return IrConversionErrorStatus( node->span(), - absl::StrFormat("CheckedCast to and from array " - "is not currently supported for IR conversion; " + absl::StrFormat("CheckedCast is only supported for bits-like types in " + "IR conversion; " "attempted checked cast from: %s to: %s", input_type->ToString(), output_type->ToString()), file_table()); diff --git a/xls/dslx/stdlib/apfloat.x b/xls/dslx/stdlib/apfloat.x index 81785f94bb..acae738e94 100644 --- a/xls/dslx/stdlib/apfloat.x +++ b/xls/dslx/stdlib/apfloat.x @@ -1503,7 +1503,7 @@ fn test_fp_lt_2() { fn to_signed_or_unsigned_int (x: APFloat) -> xN[RESULT_SIGNED][RESULT_SZ] { const WIDE_FRACTION: u32 = FRACTION_SZ + u32:1; - const MAX_FRACTION_SZ: u32 = std::umax(RESULT_SZ, WIDE_FRACTION); + const MAX_FRACTION_SZ: u32 = std::max(RESULT_SZ, WIDE_FRACTION); const INT_MIN = if RESULT_SIGNED { (uN[MAX_FRACTION_SZ]:1 << (RESULT_SZ - u32:1)) // or rather, its negative. diff --git a/xls/dslx/stdlib/std.x b/xls/dslx/stdlib/std.x index a0af8664db..eeab426bba 100644 --- a/xls/dslx/stdlib/std.x +++ b/xls/dslx/stdlib/std.x @@ -90,74 +90,77 @@ fn unsigned_max_value_test() { assert_eq(u32:0xffffffff, unsigned_max_value()); } -// Returns the maximum of two signed integers. -pub fn smax(x: sN[N], y: sN[N]) -> sN[N] { if x > y { x } else { y } } +// Returns the maximum of two (signed or unsigned) integers. +pub fn max(x: xN[S][N], y: xN[S][N]) -> xN[S][N] { if x > y { x } else { y } } #[test] -fn smax_test() { - assert_eq(s2:0, smax(s2:0, s2:0)); - assert_eq(s2:1, smax(s2:-1, s2:1)); - assert_eq(s7:-3, smax(s7:-3, s7:-6)); +fn max_test_signed() { + assert_eq(s2:0, max(s2:0, s2:0)); + assert_eq(s2:1, max(s2:-1, s2:1)); + assert_eq(s7:-3, max(s7:-3, s7:-6)); } -// Returns the maximum of two unsigned integers. -pub fn umax(x: uN[N], y: uN[N]) -> uN[N] { if x > y { x } else { y } } - #[test] -fn umax_test() { - assert_eq(u1:1, umax(u1:1, u1:0)); - assert_eq(u1:1, umax(u1:1, u1:1)); - assert_eq(u2:3, umax(u2:3, u2:2)); +fn max_test_unsigned() { + assert_eq(u1:1, max(u1:1, u1:0)); + assert_eq(u1:1, max(u1:1, u1:1)); + assert_eq(u2:3, max(u2:3, u2:2)); } -// Returns the maximum of two signed integers. -pub fn smin(x: sN[N], y: sN[N]) -> sN[N] { if x < y { x } else { y } } +// Returns the minimum of two (signed or unsigned) integers. +pub fn min(x: xN[S][N], y: xN[S][N]) -> xN[S][N] { if x < y { x } else { y } } -#[test] -fn smin_test() { - assert_eq(s1:0, smin(s1:0, s1:0)); - assert_eq(s1:-1, smin(s1:0, s1:1)); - assert_eq(s1:-1, smin(s1:1, s1:0)); - assert_eq(s1:-1, smin(s1:1, s1:1)); +// TODO(meheff): Remove when all uses have been ported to std::min/std::max. +pub fn smax(x: sN[N], y: sN[N]) -> sN[N] { max(x, y) } - assert_eq(s2:-2, smin(s2:0, s2:-2)); - assert_eq(s2:-1, smin(s2:0, s2:-1)); - assert_eq(s2:0, smin(s2:0, s2:0)); - assert_eq(s2:0, smin(s2:0, s2:1)); +pub fn smin(x: sN[N], y: sN[N]) -> sN[N] { min(x, y) } - assert_eq(s2:-2, smin(s2:1, s2:-2)); - assert_eq(s2:-1, smin(s2:1, s2:-1)); - assert_eq(s2:0, smin(s2:1, s2:0)); - assert_eq(s2:1, smin(s2:1, s2:1)); +pub fn umax(x: uN[N], y: uN[N]) -> uN[N] { max(x, y) } - assert_eq(s2:-2, smin(s2:-2, s2:-2)); - assert_eq(s2:-2, smin(s2:-2, s2:-1)); - assert_eq(s2:-2, smin(s2:-2, s2:0)); - assert_eq(s2:-2, smin(s2:-2, s2:1)); +pub fn umin(x: uN[N], y: uN[N]) -> uN[N] { min(x, y) } - assert_eq(s2:-2, smin(s2:-1, s2:-2)); - assert_eq(s2:-1, smin(s2:-1, s2:-1)); - assert_eq(s2:-1, smin(s2:-1, s2:0)); - assert_eq(s2:-1, smin(s2:-1, s2:1)); +#[test] +fn min_test_unsigned() { + assert_eq(u1:0, min(u1:1, u1:0)); + assert_eq(u1:1, min(u1:1, u1:1)); + assert_eq(u2:2, min(u2:3, u2:2)); } -// Returns the minimum of two unsigned integers. -pub fn umin(x: uN[N], y: uN[N]) -> uN[N] { if x < y { x } else { y } } - #[test] -fn umin_test() { - assert_eq(u1:0, umin(u1:1, u1:0)); - assert_eq(u1:1, umin(u1:1, u1:1)); - assert_eq(u2:2, umin(u2:3, u2:2)); +fn min_test_signed() { + assert_eq(s1:0, min(s1:0, s1:0)); + assert_eq(s1:-1, min(s1:0, s1:1)); + assert_eq(s1:-1, min(s1:1, s1:0)); + assert_eq(s1:-1, min(s1:1, s1:1)); + + assert_eq(s2:-2, min(s2:0, s2:-2)); + assert_eq(s2:-1, min(s2:0, s2:-1)); + assert_eq(s2:0, min(s2:0, s2:0)); + assert_eq(s2:0, min(s2:0, s2:1)); + + assert_eq(s2:-2, min(s2:1, s2:-2)); + assert_eq(s2:-1, min(s2:1, s2:-1)); + assert_eq(s2:0, min(s2:1, s2:0)); + assert_eq(s2:1, min(s2:1, s2:1)); + + assert_eq(s2:-2, min(s2:-2, s2:-2)); + assert_eq(s2:-2, min(s2:-2, s2:-1)); + assert_eq(s2:-2, min(s2:-2, s2:0)); + assert_eq(s2:-2, min(s2:-2, s2:1)); + + assert_eq(s2:-2, min(s2:-1, s2:-2)); + assert_eq(s2:-1, min(s2:-1, s2:-1)); + assert_eq(s2:-1, min(s2:-1, s2:0)); + assert_eq(s2:-1, min(s2:-1, s2:1)); } // Returns unsigned add of x (N bits) and y (M bits) as a max(N,M)+1 bit value. -pub fn uadd(x: uN[N], y: uN[M]) -> uN[R] { +pub fn uadd(x: uN[N], y: uN[M]) -> uN[R] { (x as uN[R]) + (y as uN[R]) } // Returns signed add of x (N bits) and y (M bits) as a max(N,M)+1 bit value. -pub fn sadd(x: sN[N], y: sN[M]) -> sN[R] { +pub fn sadd(x: sN[N], y: sN[M]) -> sN[R] { (x as sN[R]) + (y as sN[R]) } @@ -773,7 +776,7 @@ fn test_to_unsigned() { // let result : (bool, u16) = uadd_with_overflow(x, y); // pub fn uadd_with_overflow - + (x: uN[N], y: uN[M]) -> (bool, uN[V]) { let x_extended = widening_cast(x); @@ -801,47 +804,48 @@ fn test_uadd_with_overflow() { } // Extract bits given a fixed-point integer with a constant offset. -// i.e. let x_extended = x as uN[max(unsigned_sizeof(x) + fixed_shift, to_exclusive)]; -// (x_extended << fixed_shift)[from_inclusive:to_exclusive] +// i.e. let x_extended = x as uN[max(unsigned_sizeof(x) + FIXED_SHIFT, TO_EXCLUSIVE)]; +// (x_extended << FIXED_SHIFT)[FROM_INCLUSIVE:TO_EXCLUSIVE] // // This function behaves as-if x has reasonably infinite precision so that -// the result is zero-padded if from_inclusive or to_exclusive are out of +// the result is zero-padded if FROM_INCLUSIVE or TO_EXCLUSIVE are out of // range of the original x's bitwidth. // -// If to_exclusive <= from_exclusive, the result will be a zero-bit uN[0]. +// If TO_EXCLUSIVE <= FROM_INCLUSIVE, the result will be a zero-bit uN[0]. pub fn extract_bits - - (x: uN[N]) -> uN[extract_width] { - if to_exclusive <= from_inclusive { - uN[extract_width]:0 + + (x: uN[N]) -> uN[EXTRACT_WIDTH] { + if TO_EXCLUSIVE <= FROM_INCLUSIVE { + uN[EXTRACT_WIDTH]:0 } else { // With a non-zero fixed width, all lower bits of index < fixed_shift are // are zero. let lower_bits = - uN[checked_cast(smax(s32:0, fixed_shift as s32 - from_inclusive as s32))]:0; + uN[checked_cast(max(s32:0, FIXED_SHIFT as s32 - FROM_INCLUSIVE as s32))]:0; // Based on the input of N bits and a fixed shift, there are an effective // count of N + fixed_shift known bits. All bits of index > // N + fixed_shift - 1 are zero's. const UPPER_BIT_COUNT = checked_cast( - smax(s32:0, N as s32 + fixed_shift as s32 - to_exclusive as s32 - s32:1)); - let upper_bits = uN[UPPER_BIT_COUNT]:0; + max(s32:0, N as s32 + FIXED_SHIFT as s32 - TO_EXCLUSIVE as s32 - s32:1)); + const UPPER_BITS = uN[UPPER_BIT_COUNT]:0; - if fixed_shift < from_inclusive { + if FIXED_SHIFT < FROM_INCLUSIVE { // The bits extracted start within or after the middle span. // upper_bits ++ middle_bits - let middle_bits = upper_bits ++ - x[smin(from_inclusive as s32 - fixed_shift as s32, N as s32) - :smin(to_exclusive as s32 - fixed_shift as s32, N as s32)]; - (upper_bits ++ middle_bits) as uN[extract_width] - } else if fixed_shift <= to_exclusive { + const FROM: s32 = min(FROM_INCLUSIVE as s32 - FIXED_SHIFT as s32, N as s32); + const TO: s32 = min(TO_EXCLUSIVE as s32 - FIXED_SHIFT as s32, N as s32); + let middle_bits = UPPER_BITS ++ x[FROM:TO]; + (UPPER_BITS ++ middle_bits) as uN[EXTRACT_WIDTH] + } else if FIXED_SHIFT <= TO_EXCLUSIVE { // The bits extracted start within the fixed_shift span. - let middle_bits = x[0:smin(to_exclusive as s32 - fixed_shift as s32, N as s32)]; + const TO: s32 = min(TO_EXCLUSIVE as s32 - FIXED_SHIFT as s32, N as s32); + let middle_bits = x[0:TO]; - (upper_bits ++ middle_bits ++ lower_bits) as uN[extract_width] + (UPPER_BITS ++ middle_bits ++ lower_bits) as uN[EXTRACT_WIDTH] } else { - uN[extract_width]:0 + uN[EXTRACT_WIDTH]:0 } } } @@ -928,7 +932,7 @@ pub fn umul_with_overflow > u32:1}, N_upper_bits: u32 = {N - N_lower_bits}, M_lower_bits: u32 = {M >> u32:1}, M_upper_bits: u32 = {M - M_lower_bits}, - Min_N_M_lower_bits: u32 = {umin(N_lower_bits, M_lower_bits)}, N_Plus_M: u32 = {N + M}> + Min_N_M_lower_bits: u32 = {min(N_lower_bits, M_lower_bits)}, N_Plus_M: u32 = {N + M}> (x: uN[N], y: uN[M]) -> (bool, uN[V]) { // Break x and y into two halves. // x = x1 ++ x0, diff --git a/xls/dslx/tests/BUILD b/xls/dslx/tests/BUILD index 18ff9fbe86..12940d513e 100644 --- a/xls/dslx/tests/BUILD +++ b/xls/dslx/tests/BUILD @@ -185,6 +185,10 @@ dslx_lang_test(name = "xn_type_equivalence") dslx_lang_test(name = "xn_signedness_properties") +dslx_lang_test(name = "xn_slice_bounds") + +dslx_lang_test(name = "xn_widening_cast") + dslx_lang_test( name = "parametric_shift", # TODO(leary): 2023-08-14 Runs into "cannot translate zero length bitvector @@ -967,7 +971,7 @@ dslx_lang_test( ) xls_dslx_opt_ir( - name = "mod_const_importer", + name = "mod_const_importer_opt_ir", srcs = ["mod_const_importer.x"], dslx_top = "main", deps = [":mod_simple_const_dslx"], @@ -975,7 +979,7 @@ xls_dslx_opt_ir( xls_dslx_opt_ir_test( name = "mod_const_importer_test", - dep = ":mod_const_importer", + dep = ":mod_const_importer_opt_ir", ) dslx_lang_test( diff --git a/xls/dslx/tests/errors/error_modules_test.py b/xls/dslx/tests/errors/error_modules_test.py index 0c2efd9c19..a97f496233 100644 --- a/xls/dslx/tests/errors/error_modules_test.py +++ b/xls/dslx/tests/errors/error_modules_test.py @@ -873,8 +873,8 @@ def test_equals_rhs_undefined_nameref(self): def test_umin_type_mismatch(self): stderr = self._run('xls/dslx/tests/errors/umin_type_mismatch.x') - self.assertIn('umin_type_mismatch.x:21:12-21:27', stderr) - self.assertIn('XlsTypeError: uN[N] vs uN[8]', stderr) + self.assertIn('umin_type_mismatch.x:21:13-21:28', stderr) + self.assertIn('saw: 42; then: 8', stderr) def test_diag_block_with_trailing_semi(self): stderr = self._run( diff --git a/xls/dslx/tests/errors/spawn_wrong_argc.x b/xls/dslx/tests/errors/spawn_wrong_argc.x index d3ec922987..8b2242c4d7 100644 --- a/xls/dslx/tests/errors/spawn_wrong_argc.x +++ b/xls/dslx/tests/errors/spawn_wrong_argc.x @@ -19,7 +19,7 @@ pub proc foo { config () { () } next(state: ()) { - std::umin(u32:1, u32:2); + std::min(u32:1, u32:2); () } } @@ -37,7 +37,7 @@ proc test_case { } next(state: ()) { - std::umin(u32:1, u32:2); + std::min(u32:1, u32:2); let tok = send(join(), terminator, true); () } diff --git a/xls/dslx/tests/errors/umin_type_mismatch.x b/xls/dslx/tests/errors/umin_type_mismatch.x index 6ebfa0f207..129f5f2dc0 100644 --- a/xls/dslx/tests/errors/umin_type_mismatch.x +++ b/xls/dslx/tests/errors/umin_type_mismatch.x @@ -18,5 +18,5 @@ const MY_U32 = u42:42; const MY_U8 = u8:42; fn f() -> u32 { - std::umin(MY_U32, MY_U8) + std::min(MY_U32, MY_U8) } diff --git a/xls/dslx/tests/xn_slice_bounds.x b/xls/dslx/tests/xn_slice_bounds.x new file mode 100644 index 0000000000..a584121e5b --- /dev/null +++ b/xls/dslx/tests/xn_slice_bounds.x @@ -0,0 +1,33 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +const S = true; +const N = u32:32; + +type MyS32 = xN[S][N]; + +fn from_to(x: u32) -> u8 { x[MyS32:0:MyS32:8] } + +fn to(x: u32) -> u8 { x[:MyS32:8] } + +fn from(x: u32) -> u8 { x[MyS32:-8:] } + +fn main(x: u32) -> u8[3] { [from_to(x), to(x), from(x)] } + +#[test] +fn test_main() { + assert_eq(from_to(u32:0x12345678), u8:0x78); + assert_eq(to(u32:0x12345678), u8:0x78); + assert_eq(from(u32:0x12345678), u8:0x12); +} diff --git a/xls/dslx/tests/xn_widening_cast.x b/xls/dslx/tests/xn_widening_cast.x new file mode 100644 index 0000000000..ef06644bc9 --- /dev/null +++ b/xls/dslx/tests/xn_widening_cast.x @@ -0,0 +1,17 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import std; + +fn main(x: u7, y: u7) -> u32 { widening_cast(std::max(x, y)) } diff --git a/xls/dslx/type_system/deduce.cc b/xls/dslx/type_system/deduce.cc index 0764118410..628852d18c 100644 --- a/xls/dslx/type_system/deduce.cc +++ b/xls/dslx/type_system/deduce.cc @@ -1244,48 +1244,52 @@ static absl::StatusOr> DeduceSliceType( ctx->type_info()->AddSliceStartAndWidth(slice, fn_parametric_env, saw); // Make sure the start and end types match and that the limit fits. - std::unique_ptr start_type; - std::unique_ptr limit_type; + std::optional start_bits_like; + std::optional limit_bits_like; if (slice->start() == nullptr && slice->limit() == nullptr) { - start_type = BitsType::MakeS32(); - limit_type = BitsType::MakeS32(); + start_bits_like.emplace( + BitsLikeProperties{.is_signed = TypeDim::CreateBool(true), + .size = TypeDim::CreateU32(32)}); + limit_bits_like.emplace( + BitsLikeProperties{.is_signed = TypeDim::CreateBool(true), + .size = TypeDim::CreateU32(32)}); } else if (slice->start() != nullptr && slice->limit() == nullptr) { - XLS_ASSIGN_OR_RETURN(BitsType * tmp, - ctx->type_info()->GetItemAs(slice->start())); - start_type = tmp->CloneToUnique(); - limit_type = start_type->CloneToUnique(); + std::optional start_type = ctx->type_info()->GetItem(slice->start()); + XLS_RET_CHECK(start_type.has_value()); + start_bits_like = GetBitsLike(*start_type.value()); + limit_bits_like.emplace(Clone(start_bits_like.value())); } else if (slice->start() == nullptr && slice->limit() != nullptr) { - XLS_ASSIGN_OR_RETURN(BitsType * tmp, - ctx->type_info()->GetItemAs(slice->limit())); - limit_type = tmp->CloneToUnique(); - start_type = limit_type->CloneToUnique(); + std::optional limit_type = ctx->type_info()->GetItem(slice->limit()); + XLS_RET_CHECK(limit_type.has_value()); + limit_bits_like = GetBitsLike(*limit_type.value()); + start_bits_like.emplace(Clone(limit_bits_like.value())); } else { - XLS_ASSIGN_OR_RETURN(BitsType * tmp, - ctx->type_info()->GetItemAs(slice->start())); - start_type = tmp->CloneToUnique(); - XLS_ASSIGN_OR_RETURN(tmp, - ctx->type_info()->GetItemAs(slice->limit())); - limit_type = tmp->CloneToUnique(); + std::optional start_type = ctx->type_info()->GetItem(slice->start()); + XLS_RET_CHECK(start_type.has_value()); + start_bits_like = GetBitsLike(*start_type.value()); + + std::optional limit_type = ctx->type_info()->GetItem(slice->limit()); + XLS_RET_CHECK(limit_type.has_value()); + limit_bits_like = GetBitsLike(*limit_type.value()); } - if (*start_type != *limit_type) { + if (*start_bits_like != *limit_bits_like) { return TypeInferenceErrorStatus( - node->span(), limit_type.get(), + node->span(), nullptr, absl::StrFormat( "Slice limit type (%s) did not match slice start type (%s).", - limit_type->ToString(), start_type->ToString()), + ToTypeString(*limit_bits_like), ToTypeString(*start_bits_like)), ctx->file_table()); } - XLS_ASSIGN_OR_RETURN(TypeDim type_width_dim, start_type->GetTotalBitCount()); + const TypeDim& type_width_dim = start_bits_like->size; XLS_ASSIGN_OR_RETURN(int64_t type_width, type_width_dim.GetAsInt64()); if (Bits::MinBitCountSigned(saw.start + saw.width) > type_width) { return TypeInferenceErrorStatus( - node->span(), limit_type.get(), + node->span(), nullptr, absl::StrFormat("Slice limit does not fit in index type: %d.", saw.start + saw.width), ctx->file_table()); } - return std::make_unique(/*signed=*/false, saw.width); } diff --git a/xls/dslx/type_system/type.cc b/xls/dslx/type_system/type.cc index 97d1171e69..a63e28fed2 100644 --- a/xls/dslx/type_system/type.cc +++ b/xls/dslx/type_system/type.cc @@ -1117,6 +1117,16 @@ bool IsBitsLike(const Type& t) { IsArrayOfBitsConstructor(t); } +std::string ToTypeString(const BitsLikeProperties& properties) { + if (properties.is_signed.IsParametric()) { + return absl::StrFormat("xN[%s][%s]", properties.is_signed.ToString(), + properties.size.ToString()); + } + bool is_signed = properties.is_signed.GetAsBool().value(); + return absl::StrFormat("%sN[%s]", is_signed ? "s" : "u", + properties.size.ToString()); +} + std::optional GetBitsLike(const Type& t) { if (auto* bits_type = dynamic_cast(&t); bits_type != nullptr) { diff --git a/xls/dslx/type_system/type.h b/xls/dslx/type_system/type.h index 91cd44a4e2..4d1fd75ea6 100644 --- a/xls/dslx/type_system/type.h +++ b/xls/dslx/type_system/type.h @@ -608,6 +608,8 @@ class StructTypeBase : public Type { // things like type comparisons class StructType : public StructTypeBase { public: + static std::string GetDebugName() { return "StructType"; } + StructType(std::vector> members, const StructDef& struct_def, absl::flat_hash_map @@ -751,6 +753,8 @@ class TupleType : public Type { // These will nest in the case of multidimensional arrays. class ArrayType : public Type { public: + static std::string GetDebugName() { return "ArrayType"; } + ArrayType(std::unique_ptr element_type, const TypeDim& size); absl::Status Accept(TypeVisitor& v) const override { @@ -878,6 +882,8 @@ class BitsConstructorType : public Type { // respectively. class BitsType : public Type { public: + static std::string GetDebugName() { return "BitsType"; } + static std::unique_ptr MakeU64() { return std::make_unique(false, 64); } @@ -941,6 +947,8 @@ class BitsType : public Type { // Represents a function type with params and a return type. class FunctionType : public Type { public: + static std::string GetDebugName() { return "FunctionType"; } + FunctionType(std::vector> params, std::unique_ptr return_type) : params_(std::move(params)), return_type_(std::move(return_type)) { @@ -1106,6 +1114,15 @@ struct BitsLikeProperties { TypeDim size; }; +// Returns a string representation of the BitsLikeProperties that looks similar +// to a corresponding BitsType. +std::string ToTypeString(const BitsLikeProperties& properties); + +inline BitsLikeProperties Clone(const BitsLikeProperties& properties) { + return BitsLikeProperties{.is_signed = properties.is_signed.Clone(), + .size = properties.size.Clone()}; +} + inline bool operator==(const BitsLikeProperties& a, const BitsLikeProperties& b) { return a.is_signed == b.is_signed && a.size == b.size; diff --git a/xls/dslx/type_system/type_info.h b/xls/dslx/type_system/type_info.h index 9938ccdc31..f4b692a708 100644 --- a/xls/dslx/type_system/type_info.h +++ b/xls/dslx/type_system/type_info.h @@ -30,6 +30,7 @@ #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/variant.h" +#include "xls/common/status/ret_check.h" #include "xls/dslx/frontend/ast.h" #include "xls/dslx/frontend/pos.h" #include "xls/dslx/interp_value.h" @@ -403,6 +404,9 @@ class TypeInfo { template inline absl::StatusOr TypeInfo::GetItemAs(const AstNode* key) const { + static_assert(std::is_base_of::value, + "T must be a subclass of Type"); + std::optional t = GetItem(key); if (!t.has_value()) { return absl::NotFoundError( @@ -411,11 +415,11 @@ inline absl::StatusOr TypeInfo::GetItemAs(const AstNode* key) const { } DCHECK(t.value() != nullptr); auto* target = dynamic_cast(t.value()); - if (target == nullptr) { - return absl::FailedPreconditionError(absl::StrFormat( - "AST node (%s) @ %s did not have expected Type subtype.", - key->GetNodeTypeName(), SpanToString(key->GetSpan(), file_table()))); - } + XLS_RET_CHECK(target != nullptr) << absl::StreamFormat( + "AST node `%s` @ %s did not have expected `xls::dslx::Type` subtype; " + "want: %s got: %s", + key->ToString(), SpanToString(key->GetSpan(), file_table()), + T::GetDebugName(), t.value()->GetDebugTypeName()); return target; } diff --git a/xls/dslx/type_system/typecheck_invocation.cc b/xls/dslx/type_system/typecheck_invocation.cc index 197100ecde..b16797d97b 100644 --- a/xls/dslx/type_system/typecheck_invocation.cc +++ b/xls/dslx/type_system/typecheck_invocation.cc @@ -98,10 +98,12 @@ static absl::Status TypecheckIsAcceptableWideningCast(DeduceCtx* ctx, XLS_RET_CHECK(maybe_from_type.has_value()); XLS_RET_CHECK(maybe_to_type.has_value()); - BitsType* from = dynamic_cast(maybe_from_type.value()); - BitsType* to = dynamic_cast(maybe_to_type.value()); + std::optional from_bits_like = + GetBitsLike(*maybe_from_type.value()); + std::optional to_bits_like = + GetBitsLike(*maybe_to_type.value()); - if (from == nullptr || to == nullptr) { + if (!from_bits_like.has_value() || !to_bits_like.has_value()) { return ctx->TypeMismatchError( node->span(), from_expr, *maybe_from_type.value(), node, *maybe_to_type.value(), @@ -110,13 +112,13 @@ static absl::Status TypecheckIsAcceptableWideningCast(DeduceCtx* ctx, maybe_to_type.value()->ToErrorString())); } - bool signed_input = from->is_signed(); - bool signed_output = to->is_signed(); + XLS_ASSIGN_OR_RETURN(bool signed_input, + from_bits_like->is_signed.GetAsBool()); + XLS_ASSIGN_OR_RETURN(bool signed_output, to_bits_like->is_signed.GetAsBool()); XLS_ASSIGN_OR_RETURN(int64_t old_bit_count, - from->GetTotalBitCount().value().GetAsInt64()); - XLS_ASSIGN_OR_RETURN(int64_t new_bit_count, - to->GetTotalBitCount().value().GetAsInt64()); + from_bits_like->size.GetAsInt64()); + XLS_ASSIGN_OR_RETURN(int64_t new_bit_count, to_bits_like->size.GetAsInt64()); bool can_cast = ((signed_input == signed_output) && (new_bit_count >= old_bit_count)) || @@ -128,8 +130,8 @@ static absl::Status TypecheckIsAcceptableWideningCast(DeduceCtx* ctx, *maybe_to_type.value(), absl::StrFormat("Can not cast from type %s (%d bits) to" " %s (%d bits) with widening_cast", - from->ToString(), old_bit_count, to->ToString(), - new_bit_count)); + ToTypeString(from_bits_like.value()), old_bit_count, + ToTypeString(to_bits_like.value()), new_bit_count)); } return absl::OkStatus(); diff --git a/xls/modules/zstd/dec_mux.x b/xls/modules/zstd/dec_mux.x index 59778ff304..a877cc4aa5 100644 --- a/xls/modules/zstd/dec_mux.x +++ b/xls/modules/zstd/dec_mux.x @@ -117,7 +117,7 @@ pub proc DecoderMux { let all_valid = state.raw_data_valid && state.rle_data_valid && state.compressed_data_valid; let state = if (any_valid) { - let min_id = std::umin(std::umin(rle_id, raw_id), compressed_id); + let min_id = std::min(std::min(rle_id, raw_id), compressed_id); trace_fmt!("DecoderMux: rle_id: {}, raw_id: {}, compressed_id: {}", rle_id, raw_id, compressed_id); trace_fmt!("DecoderMux: min_id: {}", min_id); diff --git a/xls/modules/zstd/memory/axi_reader.x b/xls/modules/zstd/memory/axi_reader.x index 02ea24aff7..43504eea30 100644 --- a/xls/modules/zstd/memory/axi_reader.x +++ b/xls/modules/zstd/memory/axi_reader.x @@ -139,7 +139,7 @@ pub proc AxiReader< let bytes_to_max_burst = MAX_AXI_BURST_BYTES - aligned_offset as Length; let bytes_to_4k = common::bytes_to_4k_boundary(state.tran_addr); - let tran_len = std::umin(state.tran_len, std::umin(bytes_to_4k, bytes_to_max_burst)); + let tran_len = std::min(state.tran_len, std::min(bytes_to_4k, bytes_to_max_burst)); let (req_low_lane, req_high_lane) = common::get_lanes(state.tran_addr, tran_len); let adjusted_tran_len = aligned_offset as Addr + tran_len; diff --git a/xls/modules/zstd/memory/axi_writer.x b/xls/modules/zstd/memory/axi_writer.x index 982c444bcf..2f62307731 100644 --- a/xls/modules/zstd/memory/axi_writer.x +++ b/xls/modules/zstd/memory/axi_writer.x @@ -164,7 +164,7 @@ pub proc AxiWriter< } }, Fsm::TRANSFER_LENGTH => { - let tran_len = std::umin(state.transfer_data.length, std::umin(state.bytes_to_4k, state.bytes_to_max_axi_burst)); + let tran_len = std::min(state.transfer_data.length, std::min(state.bytes_to_4k, state.bytes_to_max_axi_burst)); State { fsm: Fsm::CALC_NEXT_TRANSFER, transaction_len: tran_len, diff --git a/xls/modules/zstd/sequence_executor.x b/xls/modules/zstd/sequence_executor.x index 422185d236..a1fea91d50 100644 --- a/xls/modules/zstd/sequence_executor.x +++ b/xls/modules/zstd/sequence_executor.x @@ -482,7 +482,7 @@ fn sequence_packet_to_read_reqs -> (ram::ReadReq[RAM_NUM], RamOrder[RAM_NUM], SequenceExecutorPacket, bool) { type ReadReq = ram::ReadReq; - let max_len = std::umin(seq.length as u32, std::umin(RAM_NUM, hb_len)); + let max_len = std::min(seq.length as u32, std::min(RAM_NUM, hb_len)); let (next_seq, next_seq_valid) = if seq.length > max_len as CopyOrMatchLength { ( diff --git a/xls/modules/zstd/zstd_dec.x b/xls/modules/zstd/zstd_dec.x index 259361de8f..0f9fac906e 100644 --- a/xls/modules/zstd/zstd_dec.x +++ b/xls/modules/zstd/zstd_dec.x @@ -179,7 +179,7 @@ fn feed_block_decoder(state: ZstdDecoderState) -> (bool, BlockDataPacket, ZstdDe trace_fmt!("zstd_dec: feed_block_decoder: buffer_length_bytes: {}", buffer_length_bytes); let data_width_bytes = (DATA_WIDTH >> 3) as BlockSize; trace_fmt!("zstd_dec: feed_block_decoder: data_width_bytes: {}", data_width_bytes); - let remaining_bytes_to_send_now = std::umin(remaining_bytes_to_send, data_width_bytes); + let remaining_bytes_to_send_now = std::min(remaining_bytes_to_send, data_width_bytes); trace_fmt!("zstd_dec: feed_block_decoder: remaining_bytes_to_send_now: {}", remaining_bytes_to_send_now); if (buffer_length_bytes >= remaining_bytes_to_send_now as u32) { let remaining_bits_to_send_now = (remaining_bytes_to_send_now as u32) << 3;