Skip to content

Commit

Permalink
Support conditionals in typecheck v2.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 718407100
  • Loading branch information
dplassgit authored and copybara-github committed Jan 22, 2025
1 parent dcf204f commit 3490eb5
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 0 deletions.
26 changes: 26 additions & 0 deletions xls/dslx/type_system_v2/typecheck_module_v2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,32 @@ class PopulateInferenceTableVisitor : public AstNodeVisitorWithDefault {
return DefaultHandler(node);
}

absl::Status HandleConditional(const Conditional* node) override {
VLOG(5) << "HandleConditional: " << node->ToString();
// In the example `const D = if (a) {b} else {c};`, the `ConstantDef`
// establishes a type variable that is just propagated down to `b` and
// `c` here, meaning that `b`, `c`, and the result must ultimately be
// the same type as 'D'. The test 'a' must be a bool, so we annotate it as
// such.
const NameRef* type_variable = *table_.GetTypeVariable(node);
XLS_RETURN_IF_ERROR(
table_.SetTypeVariable(node->consequent(), type_variable));
XLS_RETURN_IF_ERROR(
table_.SetTypeVariable(ToAstNode(node->alternate()), type_variable));

// Mark the test as bool.
XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation(
node->test(), CreateBoolAnnotation(module_, node->test()->span())));
XLS_ASSIGN_OR_RETURN(
const NameRef* test_variable,
table_.DefineInternalVariable(
InferenceVariableKind::kType, const_cast<Expr*>(node->test()),
GenerateInternalTypeVariableName(node->test())));
XLS_RETURN_IF_ERROR(table_.SetTypeVariable(node->test(), test_variable));

return DefaultHandler(node);
}

absl::Status HandleXlsTuple(const XlsTuple* node) override {
VLOG(5) << "HandleXlsTuple: " << node->ToString();

Expand Down
129 changes: 129 additions & 0 deletions xls/dslx/type_system_v2/typecheck_module_v2_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1719,5 +1719,134 @@ fn foo(x: u32, y: u32) -> bool {
TypecheckFails(HasSizeMismatch("u32", "bool")));
}

TEST(TypecheckV2Test, IfType) {
EXPECT_THAT("const X = if true { u32:1 } else { u32:0 };",
TypecheckSucceeds(HasNodeWithType("X", "uN[32]")));
}

TEST(TypecheckV2Test, IfTypeMismatch) {
EXPECT_THAT("const X: u31 = if true { u32:1 } else { u32:0 };",
TypecheckFails(HasSizeMismatch("u32", "u31")));
}

TEST(TypecheckV2Test, IfTestVariable) {
EXPECT_THAT("const Y = true; const X = if Y { u32:1 } else { u32:0 };",
TypecheckSucceeds(HasNodeWithType("X", "uN[32]")));
}

TEST(TypecheckV2Test, IfTestVariableNotVariable) {
EXPECT_THAT("const Y = true; const X = if Y { Y } else { !Y };",
TypecheckSucceeds(HasNodeWithType("X", "uN[1]")));
}

TEST(TypecheckV2Test, IfTestVariables) {
EXPECT_THAT(R"(
const Y = true;
const Z = false;
const X = if (Y && Z) {u32:1} else { u32:2 };
)",
TypecheckSucceeds(AllOf(HasNodeWithType("Y", "uN[1]"),
HasNodeWithType("Z", "uN[1]"),
HasNodeWithType("X", "uN[32]"))));
}

TEST(TypecheckV2Test, IfTestBadVariable) {
EXPECT_THAT("const Y = u32:1; const X = if Y { u32:1 } else { u32:0 };",
TypecheckFails(HasSizeMismatch("u32", "bool")));
}

TEST(TypecheckV2Test, IfTestFnCall) {
EXPECT_THAT(R"(
fn f() -> bool { true }
const X = if f() { u32:1 } else { u32:0 };
)",
TypecheckSucceeds(HasNodeWithType("X", "uN[32]")));
}

TEST(TypecheckV2Test, IfTestBadFnCall) {
EXPECT_THAT(R"(
fn f() -> u32 { u32:1 }
const X = if f() { u32:1 } else { u32:0 };
)",
TypecheckFails(HasSizeMismatch("u32", "bool")));
}

TEST(TypecheckV2Test, FnReturnsIf) {
EXPECT_THAT(R"(
fn f(x:u10) -> u32 { if x>u10:0 { u32:1 } else { u32:0 } }
const X = f(u10:1);
)",
TypecheckSucceeds(HasNodeWithType("X", "uN[32]")));
}

TEST(TypecheckV2Test, CallFnWithIf) {
EXPECT_THAT(R"(
fn f(x:u32) -> u32 { x }
const X = f(if true { u32:1 } else { u32:0 });
)",
TypecheckSucceeds(HasNodeWithType("X", "uN[32]")));
}

TEST(TypecheckV2Test, IfTestInt) {
EXPECT_THAT("const X = if u32:1 { u32:1 } else { u32:0 };",
TypecheckFails(HasSizeMismatch("u32", "bool")));
}

TEST(TypecheckV2Test, IfAlternativeWrongType) {
EXPECT_THAT("const X = if true { u32:1 } else { u31:0 };",
TypecheckFails(HasSizeMismatch("u31", "u32")));
}

TEST(TypecheckV2Test, IfElseIf) {
EXPECT_THAT(R"(
const X = if false {
u32:1
} else if true {
u32:2
} else {
u32:3
};)",
TypecheckSucceeds(HasNodeWithType("X", "uN[32]")));
}

TEST(TypecheckV2Test, ElseIfMismatch) {
EXPECT_THAT(R"(
const X = if false {
u32:1
} else if true {
u31:2
} else {
u32:3
};)",
TypecheckFails(HasSizeMismatch("u31", "u32")));
}

TEST(TypecheckV2Test, ElseIfNotBool) {
EXPECT_THAT(R"(const X = if false {
u32:1
} else if u32:1 {
u32:2
} else {
u32:3
};)",
TypecheckFails(HasSizeMismatch("u32", "bool")));
}

TEST(TypecheckV2Test, IfParametricVariable) {
EXPECT_THAT(R"(
fn f<N:u32>(x: uN[N]) -> u32 { if true { N } else { N }}
const Y = f(u10:256);
)",
TypecheckSucceeds(HasNodeWithType("Y", "uN[32]")));
}

TEST(TypecheckV2Test, IfParametricType) {
EXPECT_THAT(R"(
fn f<N:u32>(x: uN[N]) -> uN[N] { if true { x } else { x }}
const Y = f(u10:256);
)",
TypecheckSucceeds(HasNodeWithType("Y", "uN[10]")));
}

} // namespace
} // namespace xls::dslx

0 comments on commit 3490eb5

Please sign in to comment.