Skip to content

Commit

Permalink
gccrs: add support for lang_item eq and PartialEq trait
Browse files Browse the repository at this point in the history
The Eq and Partial Ord are very similar to the operator overloads
we support for add/sub/etc... but they differ in that usually the
function call name matches the name of the lang item. This time
we need to have support to send in a new path for the method call
on the lang item we want instead of just the name of the lang item.

NOTE: this test case doesnt work correctly yet we need to support
the derive of partial eq on enums to generate the correct comparison
code for that.

Fixes #3302

gcc/rust/ChangeLog:

	* backend/rust-compile-expr.cc (CompileExpr::visit): handle partial_eq possible call
	* backend/rust-compile-expr.h: handle case where lang item calls differ from name
	* hir/tree/rust-hir-expr.cc (OperatorExprMeta::OperatorExprMeta): new helper
	* hir/tree/rust-hir-expr.h: likewise
	* typecheck/rust-hir-type-check-expr.cc (TypeCheckExpr::visit): handle partial_eq
	(TypeCheckExpr::resolve_operator_overload): likewise
	* typecheck/rust-hir-type-check-expr.h: likewise
	* util/rust-lang-item.cc (LangItem::ComparisonToLangItem): map comparison to lang item
	(LangItem::ComparisonToSegment): likewise
	* util/rust-lang-item.h: new lang items PartialOrd and Eq
	* util/rust-operators.h (enum class): likewise

gcc/testsuite/ChangeLog:

	* rust/compile/cmp1.rs: New test.

Signed-off-by: Philip Herron <herron.philip@googlemail.com>
  • Loading branch information
philberty committed Jan 6, 2025
1 parent 26d3103 commit c24c1a8
Show file tree
Hide file tree
Showing 10 changed files with 192 additions and 13 deletions.
28 changes: 26 additions & 2 deletions gcc/rust/backend/rust-compile-expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,26 @@ CompileExpr::visit (HIR::ComparisonExpr &expr)
auto rhs = CompileExpr::Compile (expr.get_rhs (), ctx);
auto location = expr.get_locus ();

// this might be an operator overload situation lets check
TyTy::FnType *fntype;
bool is_op_overload = ctx->get_tyctx ()->lookup_operator_overload (
expr.get_mappings ().get_hirid (), &fntype);
if (is_op_overload)
{
auto seg_name = LangItem::ComparisonToSegment (expr.get_expr_type ());
auto segment = HIR::PathIdentSegment (seg_name);
auto lang_item_type
= LangItem::ComparisonToLangItem (expr.get_expr_type ());

rhs = address_expression (rhs, EXPR_LOCATION (rhs));

translated = resolve_operator_overload (
lang_item_type, expr, lhs, rhs, expr.get_lhs (),
tl::optional<std::reference_wrapper<HIR::Expr>> (expr.get_rhs ()),
segment);
return;
}

translated = Backend::comparison_expression (op, lhs, rhs, location);
}

Expand Down Expand Up @@ -1478,7 +1498,8 @@ CompileExpr::get_receiver_from_dyn (const TyTy::DynamicObjectType *dyn,
tree
CompileExpr::resolve_operator_overload (
LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr, tree lhs, tree rhs,
HIR::Expr &lhs_expr, tl::optional<std::reference_wrapper<HIR::Expr>> rhs_expr)
HIR::Expr &lhs_expr, tl::optional<std::reference_wrapper<HIR::Expr>> rhs_expr,
HIR::PathIdentSegment specified_segment)
{
TyTy::FnType *fntype;
bool is_op_overload = ctx->get_tyctx ()->lookup_operator_overload (
Expand All @@ -1499,7 +1520,10 @@ CompileExpr::resolve_operator_overload (
}

// lookup compiled functions since it may have already been compiled
HIR::PathIdentSegment segment_name (LangItem::ToString (lang_item_type));
HIR::PathIdentSegment segment_name
= specified_segment.is_error ()
? HIR::PathIdentSegment (LangItem::ToString (lang_item_type))
: specified_segment;
tree fn_expr = resolve_method_address (fntype, receiver, expr.get_locus ());

// lookup the autoderef mappings
Expand Down
4 changes: 3 additions & 1 deletion gcc/rust/backend/rust-compile-expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ class CompileExpr : private HIRCompileBase, protected HIR::HIRExpressionVisitor
tree resolve_operator_overload (
LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr, tree lhs,
tree rhs, HIR::Expr &lhs_expr,
tl::optional<std::reference_wrapper<HIR::Expr>> rhs_expr);
tl::optional<std::reference_wrapper<HIR::Expr>> rhs_expr,
HIR::PathIdentSegment specified_segment
= HIR::PathIdentSegment::create_error ());

tree compile_bool_literal (const HIR::LiteralExpr &expr,
const TyTy::BaseType *tyty);
Expand Down
6 changes: 6 additions & 0 deletions gcc/rust/hir/tree/rust-hir-expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1298,6 +1298,12 @@ OperatorExprMeta::OperatorExprMeta (HIR::ArrayIndexExpr &expr)
locus (expr.get_locus ())
{}

OperatorExprMeta::OperatorExprMeta (HIR::ComparisonExpr &expr)
: node_mappings (expr.get_mappings ()),
lvalue_mappings (expr.get_expr ().get_mappings ()),
locus (expr.get_locus ())
{}

AnonConst::AnonConst (NodeId id, std::unique_ptr<Expr> expr)
: id (id), expr (std::move (expr))
{
Expand Down
2 changes: 2 additions & 0 deletions gcc/rust/hir/tree/rust-hir-expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -2816,6 +2816,8 @@ class OperatorExprMeta

OperatorExprMeta (HIR::ArrayIndexExpr &expr);

OperatorExprMeta (HIR::ComparisonExpr &expr);

const Analysis::NodeMapping &get_mappings () const { return node_mappings; }

const Analysis::NodeMapping &get_lvalue_mappings () const
Expand Down
27 changes: 22 additions & 5 deletions gcc/rust/typecheck/rust-hir-type-check-expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,21 @@ TypeCheckExpr::visit (HIR::ComparisonExpr &expr)
auto lhs = TypeCheckExpr::Resolve (expr.get_lhs ());
auto rhs = TypeCheckExpr::Resolve (expr.get_rhs ());

auto borrwed_rhs
= new TyTy::ReferenceType (mappings.get_next_hir_id (),
TyTy::TyVar (rhs->get_ref ()), Mutability::Imm);
context->insert_implicit_type (borrwed_rhs->get_ref (), borrwed_rhs);

auto seg_name = LangItem::ComparisonToSegment (expr.get_expr_type ());
auto segment = HIR::PathIdentSegment (seg_name);
auto lang_item_type = LangItem::ComparisonToLangItem (expr.get_expr_type ());

bool operator_overloaded
= resolve_operator_overload (lang_item_type, expr, lhs, borrwed_rhs,
segment);
if (operator_overloaded)
return;

unify_site (expr.get_mappings ().get_hirid (),
TyTy::TyWithLocation (lhs, expr.get_lhs ().get_locus ()),
TyTy::TyWithLocation (rhs, expr.get_rhs ().get_locus ()),
Expand Down Expand Up @@ -1638,10 +1653,10 @@ TypeCheckExpr::visit (HIR::ClosureExpr &expr)
}

bool
TypeCheckExpr::resolve_operator_overload (LangItem::Kind lang_item_type,
HIR::OperatorExprMeta expr,
TyTy::BaseType *lhs,
TyTy::BaseType *rhs)
TypeCheckExpr::resolve_operator_overload (
LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr,
TyTy::BaseType *lhs, TyTy::BaseType *rhs,
HIR::PathIdentSegment specified_segment)
{
// look up lang item for arithmetic type
std::string associated_item_name = LangItem::ToString (lang_item_type);
Expand All @@ -1659,7 +1674,9 @@ TypeCheckExpr::resolve_operator_overload (LangItem::Kind lang_item_type,
current_context = context->peek_context ();
}

auto segment = HIR::PathIdentSegment (associated_item_name);
auto segment = specified_segment.is_error ()
? HIR::PathIdentSegment (associated_item_name)
: specified_segment;
auto candidates = MethodResolver::Probe (lhs, segment);

// remove any recursive candidates
Expand Down
4 changes: 3 additions & 1 deletion gcc/rust/typecheck/rust-hir-type-check-expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ class TypeCheckExpr : private TypeCheckBase, private HIR::HIRExpressionVisitor
protected:
bool resolve_operator_overload (LangItem::Kind lang_item_type,
HIR::OperatorExprMeta expr,
TyTy::BaseType *lhs, TyTy::BaseType *rhs);
TyTy::BaseType *lhs, TyTy::BaseType *rhs,
HIR::PathIdentSegment specified_segment
= HIR::PathIdentSegment::create_error ());

bool resolve_fn_trait_call (HIR::CallExpr &expr,
TyTy::BaseType *function_tyty,
Expand Down
43 changes: 43 additions & 0 deletions gcc/rust/util/rust-lang-item.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ const BiMap<std::string, LangItem::Kind> Rust::LangItem::lang_items = {{

{"into_iter", Kind::INTOITER_INTOITER},
{"next", Kind::ITERATOR_NEXT},

{"eq", Kind::EQ},
{"partial_ord", Kind::PARTIAL_ORD},
}};

tl::optional<LangItem::Kind>
Expand Down Expand Up @@ -145,6 +148,46 @@ LangItem::OperatorToLangItem (ArithmeticOrLogicalOperator op)
rust_unreachable ();
}

LangItem::Kind
LangItem::ComparisonToLangItem (ComparisonOperator op)
{
switch (op)
{
case ComparisonOperator::NOT_EQUAL:
case ComparisonOperator::EQUAL:
return LangItem::Kind::EQ;

case ComparisonOperator::GREATER_THAN:
case ComparisonOperator::LESS_THAN:
case ComparisonOperator::GREATER_OR_EQUAL:
case ComparisonOperator::LESS_OR_EQUAL:
return LangItem::Kind::PARTIAL_ORD;
}
}

std::string
LangItem::ComparisonToSegment (ComparisonOperator op)
{
switch (op)
{
case ComparisonOperator::NOT_EQUAL:
return "ne";
case ComparisonOperator::EQUAL:
return "eq";
case ComparisonOperator::GREATER_THAN:
return "gt";
case ComparisonOperator::LESS_THAN:
return "lt";
case ComparisonOperator::GREATER_OR_EQUAL:
return "ge";
case ComparisonOperator::LESS_OR_EQUAL:
return "le";
}

rust_unreachable ();
return std::string ();
}

LangItem::Kind
LangItem::CompoundAssignmentOperatorToLangItem (ArithmeticOrLogicalOperator op)
{
Expand Down
5 changes: 5 additions & 0 deletions gcc/rust/util/rust-lang-item.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class LangItem

NEGATION,
NOT,
EQ,
PARTIAL_ORD,

ADD_ASSIGN,
SUB_ASSIGN,
Expand Down Expand Up @@ -136,6 +138,9 @@ class LangItem
static Kind
CompoundAssignmentOperatorToLangItem (ArithmeticOrLogicalOperator op);
static Kind NegationOperatorToLangItem (NegationOperator op);
static Kind ComparisonToLangItem (ComparisonOperator op);

static std::string ComparisonToSegment (ComparisonOperator op);
};

} // namespace Rust
Expand Down
8 changes: 4 additions & 4 deletions gcc/rust/util/rust-operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ enum class ComparisonOperator
{
EQUAL, // std::cmp::PartialEq::eq
NOT_EQUAL, // std::cmp::PartialEq::ne
GREATER_THAN, // std::cmp::PartialEq::gt
LESS_THAN, // std::cmp::PartialEq::lt
GREATER_OR_EQUAL, // std::cmp::PartialEq::ge
LESS_OR_EQUAL // std::cmp::PartialEq::le
GREATER_THAN, // std::cmp::PartialOrd::gt
LESS_THAN, // std::cmp::PartialOrd::lt
GREATER_OR_EQUAL, // std::cmp::PartialOrd::ge
LESS_OR_EQUAL // std::cmp::PartialOrd::le
};

enum class LazyBooleanOperator
Expand Down
78 changes: 78 additions & 0 deletions gcc/testsuite/rust/compile/cmp1.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// { dg-options "-w" }
// taken from https://github.com/rust-lang/rust/blob/e1884a8e3c3e813aada8254edfa120e85bf5ffca/library/core/src/cmp.rs#L98

#[lang = "sized"]
pub trait Sized {}

#[lang = "eq"]
#[stable(feature = "rust1", since = "1.0.0")]
#[doc(alias = "==")]
#[doc(alias = "!=")]
pub trait PartialEq<Rhs: ?Sized = Self> {
/// This method tests for `self` and `other` values to be equal, and is used
/// by `==`.
#[must_use]
#[stable(feature = "rust1", since = "1.0.0")]
fn eq(&self, other: &Rhs) -> bool;

/// This method tests for `!=`.
#[inline]
#[must_use]
#[stable(feature = "rust1", since = "1.0.0")]
fn ne(&self, other: &Rhs) -> bool {
!self.eq(other)
}
}

enum BookFormat {
Paperback,
Hardback,
Ebook,
}

impl PartialEq<BookFormat> for BookFormat {
fn eq(&self, other: &BookFormat) -> bool {
self == other
}
}

pub struct Book {
isbn: i32,
format: BookFormat,
}

// Implement <Book> == <BookFormat> comparisons
impl PartialEq<BookFormat> for Book {
fn eq(&self, other: &BookFormat) -> bool {
self.format == *other
}
}

// Implement <BookFormat> == <Book> comparisons
impl PartialEq<Book> for BookFormat {
fn eq(&self, other: &Book) -> bool {
*self == other.format
}
}

// Implement <Book> == <Book> comparisons
impl PartialEq<Book> for Book {
fn eq(&self, other: &Book) -> bool {
self.isbn == other.isbn
}
}

pub fn main() {
let b1 = Book {
isbn: 1,
format: BookFormat::Paperback,
};
let b2 = Book {
isbn: 2,
format: BookFormat::Paperback,
};

let _c1: bool = b1 == BookFormat::Paperback;
let _c2: bool = BookFormat::Paperback == b2;
let _c3: bool = b1 != b2;
}

0 comments on commit c24c1a8

Please sign in to comment.