From 3490eb5695bd8f7f3dac247be6b9344a33fa6caf Mon Sep 17 00:00:00 2001 From: David Plass Date: Wed, 22 Jan 2025 09:00:50 -0800 Subject: [PATCH] Support conditionals in typecheck v2. PiperOrigin-RevId: 718407100 --- .../type_system_v2/typecheck_module_v2.cc | 26 ++++ .../typecheck_module_v2_test.cc | 129 ++++++++++++++++++ 2 files changed, 155 insertions(+) diff --git a/xls/dslx/type_system_v2/typecheck_module_v2.cc b/xls/dslx/type_system_v2/typecheck_module_v2.cc index 8f9baaa570..3170ac1367 100644 --- a/xls/dslx/type_system_v2/typecheck_module_v2.cc +++ b/xls/dslx/type_system_v2/typecheck_module_v2.cc @@ -166,6 +166,32 @@ class PopulateInferenceTableVisitor : public AstNodeVisitorWithDefault { return DefaultHandler(node); } + absl::Status HandleConditional(const Conditional* node) override { + VLOG(5) << "HandleConditional: " << node->ToString(); + // In the example `const D = if (a) {b} else {c};`, the `ConstantDef` + // establishes a type variable that is just propagated down to `b` and + // `c` here, meaning that `b`, `c`, and the result must ultimately be + // the same type as 'D'. The test 'a' must be a bool, so we annotate it as + // such. + const NameRef* type_variable = *table_.GetTypeVariable(node); + XLS_RETURN_IF_ERROR( + table_.SetTypeVariable(node->consequent(), type_variable)); + XLS_RETURN_IF_ERROR( + table_.SetTypeVariable(ToAstNode(node->alternate()), type_variable)); + + // Mark the test as bool. + XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation( + node->test(), CreateBoolAnnotation(module_, node->test()->span()))); + XLS_ASSIGN_OR_RETURN( + const NameRef* test_variable, + table_.DefineInternalVariable( + InferenceVariableKind::kType, const_cast(node->test()), + GenerateInternalTypeVariableName(node->test()))); + XLS_RETURN_IF_ERROR(table_.SetTypeVariable(node->test(), test_variable)); + + return DefaultHandler(node); + } + absl::Status HandleXlsTuple(const XlsTuple* node) override { VLOG(5) << "HandleXlsTuple: " << node->ToString(); diff --git a/xls/dslx/type_system_v2/typecheck_module_v2_test.cc b/xls/dslx/type_system_v2/typecheck_module_v2_test.cc index e6e21c8dca..1a093ab1ec 100644 --- a/xls/dslx/type_system_v2/typecheck_module_v2_test.cc +++ b/xls/dslx/type_system_v2/typecheck_module_v2_test.cc @@ -1719,5 +1719,134 @@ fn foo(x: u32, y: u32) -> bool { TypecheckFails(HasSizeMismatch("u32", "bool"))); } +TEST(TypecheckV2Test, IfType) { + EXPECT_THAT("const X = if true { u32:1 } else { u32:0 };", + TypecheckSucceeds(HasNodeWithType("X", "uN[32]"))); +} + +TEST(TypecheckV2Test, IfTypeMismatch) { + EXPECT_THAT("const X: u31 = if true { u32:1 } else { u32:0 };", + TypecheckFails(HasSizeMismatch("u32", "u31"))); +} + +TEST(TypecheckV2Test, IfTestVariable) { + EXPECT_THAT("const Y = true; const X = if Y { u32:1 } else { u32:0 };", + TypecheckSucceeds(HasNodeWithType("X", "uN[32]"))); +} + +TEST(TypecheckV2Test, IfTestVariableNotVariable) { + EXPECT_THAT("const Y = true; const X = if Y { Y } else { !Y };", + TypecheckSucceeds(HasNodeWithType("X", "uN[1]"))); +} + +TEST(TypecheckV2Test, IfTestVariables) { + EXPECT_THAT(R"( +const Y = true; +const Z = false; +const X = if (Y && Z) {u32:1} else { u32:2 }; +)", + TypecheckSucceeds(AllOf(HasNodeWithType("Y", "uN[1]"), + HasNodeWithType("Z", "uN[1]"), + HasNodeWithType("X", "uN[32]")))); +} + +TEST(TypecheckV2Test, IfTestBadVariable) { + EXPECT_THAT("const Y = u32:1; const X = if Y { u32:1 } else { u32:0 };", + TypecheckFails(HasSizeMismatch("u32", "bool"))); +} + +TEST(TypecheckV2Test, IfTestFnCall) { + EXPECT_THAT(R"( +fn f() -> bool { true } +const X = if f() { u32:1 } else { u32:0 }; +)", + TypecheckSucceeds(HasNodeWithType("X", "uN[32]"))); +} + +TEST(TypecheckV2Test, IfTestBadFnCall) { + EXPECT_THAT(R"( +fn f() -> u32 { u32:1 } +const X = if f() { u32:1 } else { u32:0 }; +)", + TypecheckFails(HasSizeMismatch("u32", "bool"))); +} + +TEST(TypecheckV2Test, FnReturnsIf) { + EXPECT_THAT(R"( +fn f(x:u10) -> u32 { if x>u10:0 { u32:1 } else { u32:0 } } +const X = f(u10:1); +)", + TypecheckSucceeds(HasNodeWithType("X", "uN[32]"))); +} + +TEST(TypecheckV2Test, CallFnWithIf) { + EXPECT_THAT(R"( +fn f(x:u32) -> u32 { x } +const X = f(if true { u32:1 } else { u32:0 }); +)", + TypecheckSucceeds(HasNodeWithType("X", "uN[32]"))); +} + +TEST(TypecheckV2Test, IfTestInt) { + EXPECT_THAT("const X = if u32:1 { u32:1 } else { u32:0 };", + TypecheckFails(HasSizeMismatch("u32", "bool"))); +} + +TEST(TypecheckV2Test, IfAlternativeWrongType) { + EXPECT_THAT("const X = if true { u32:1 } else { u31:0 };", + TypecheckFails(HasSizeMismatch("u31", "u32"))); +} + +TEST(TypecheckV2Test, IfElseIf) { + EXPECT_THAT(R"( +const X = if false { + u32:1 +} else if true { + u32:2 +} else { + u32:3 +};)", + TypecheckSucceeds(HasNodeWithType("X", "uN[32]"))); +} + +TEST(TypecheckV2Test, ElseIfMismatch) { + EXPECT_THAT(R"( +const X = if false { + u32:1 +} else if true { + u31:2 +} else { + u32:3 +};)", + TypecheckFails(HasSizeMismatch("u31", "u32"))); +} + +TEST(TypecheckV2Test, ElseIfNotBool) { + EXPECT_THAT(R"(const X = if false { + u32:1 +} else if u32:1 { + u32:2 +} else { + u32:3 +};)", + TypecheckFails(HasSizeMismatch("u32", "bool"))); +} + +TEST(TypecheckV2Test, IfParametricVariable) { + EXPECT_THAT(R"( +fn f(x: uN[N]) -> u32 { if true { N } else { N }} +const Y = f(u10:256); +)", + TypecheckSucceeds(HasNodeWithType("Y", "uN[32]"))); +} + +TEST(TypecheckV2Test, IfParametricType) { + EXPECT_THAT(R"( +fn f(x: uN[N]) -> uN[N] { if true { x } else { x }} +const Y = f(u10:256); +)", + TypecheckSucceeds(HasNodeWithType("Y", "uN[10]"))); +} + } // namespace } // namespace xls::dslx