Skip to content

Commit

Permalink
gccrs: add support for lang_item eq and PartialEq trait
Browse files Browse the repository at this point in the history
The Eq and Partial Ord are very similar to the operator overloads
we support for add/sub/etc... but they differ in that usually the
function call name matches the name of the lang item. This time
we need to have support to send in a new path for the method call
on the lang item we want instead of just the name of the lang item.

NOTE: this test case doesnt work correctly yet we need to support
the derive of partial eq on enums to generate the correct comparison
code for that.

Fixes #3302

gcc/rust/ChangeLog:

	* backend/rust-compile-expr.cc (CompileExpr::visit): handle partial_eq possible call
	* backend/rust-compile-expr.h: handle case where lang item calls differ from name
	* hir/tree/rust-hir-expr.cc (OperatorExprMeta::OperatorExprMeta): new helper
	* hir/tree/rust-hir-expr.h: likewise
	* typecheck/rust-hir-type-check-expr.cc (TypeCheckExpr::visit): handle partial_eq
	(TypeCheckExpr::resolve_operator_overload): likewise
	* typecheck/rust-hir-type-check-expr.h: likewise
	* util/rust-lang-item.cc (LangItem::ComparisonToLangItem): map comparison to lang item
	(LangItem::ComparisonToSegment): likewise
	* util/rust-lang-item.h: new lang items PartialOrd and Eq
	* util/rust-operators.h (enum class): likewise

gcc/testsuite/ChangeLog:

	* rust/compile/nr2/exclude: nr2 cant handle this
	* rust/compile/cmp1.rs: New test.

Signed-off-by: Philip Herron <herron.philip@googlemail.com>
  • Loading branch information
philberty committed Jan 7, 2025
1 parent 48aa71c commit b603ed0
Show file tree
Hide file tree
Showing 11 changed files with 194 additions and 13 deletions.
28 changes: 26 additions & 2 deletions gcc/rust/backend/rust-compile-expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,26 @@ CompileExpr::visit (HIR::ComparisonExpr &expr)
auto rhs = CompileExpr::Compile (expr.get_rhs (), ctx);
auto location = expr.get_locus ();

// this might be an operator overload situation lets check
TyTy::FnType *fntype;
bool is_op_overload = ctx->get_tyctx ()->lookup_operator_overload (
expr.get_mappings ().get_hirid (), &fntype);
if (is_op_overload)
{
auto seg_name = LangItem::ComparisonToSegment (expr.get_expr_type ());
auto segment = HIR::PathIdentSegment (seg_name);
auto lang_item_type
= LangItem::ComparisonToLangItem (expr.get_expr_type ());

rhs = address_expression (rhs, EXPR_LOCATION (rhs));

translated = resolve_operator_overload (
lang_item_type, expr, lhs, rhs, expr.get_lhs (),
tl::optional<std::reference_wrapper<HIR::Expr>> (expr.get_rhs ()),
segment);
return;
}

translated = Backend::comparison_expression (op, lhs, rhs, location);
}

Expand Down Expand Up @@ -1478,7 +1498,8 @@ CompileExpr::get_receiver_from_dyn (const TyTy::DynamicObjectType *dyn,
tree
CompileExpr::resolve_operator_overload (
LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr, tree lhs, tree rhs,
HIR::Expr &lhs_expr, tl::optional<std::reference_wrapper<HIR::Expr>> rhs_expr)
HIR::Expr &lhs_expr, tl::optional<std::reference_wrapper<HIR::Expr>> rhs_expr,
HIR::PathIdentSegment specified_segment)
{
TyTy::FnType *fntype;
bool is_op_overload = ctx->get_tyctx ()->lookup_operator_overload (
Expand All @@ -1499,7 +1520,10 @@ CompileExpr::resolve_operator_overload (
}

// lookup compiled functions since it may have already been compiled
HIR::PathIdentSegment segment_name (LangItem::ToString (lang_item_type));
HIR::PathIdentSegment segment_name
= specified_segment.is_error ()
? HIR::PathIdentSegment (LangItem::ToString (lang_item_type))
: specified_segment;
tree fn_expr = resolve_method_address (fntype, receiver, expr.get_locus ());

// lookup the autoderef mappings
Expand Down
4 changes: 3 additions & 1 deletion gcc/rust/backend/rust-compile-expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ class CompileExpr : private HIRCompileBase, protected HIR::HIRExpressionVisitor
tree resolve_operator_overload (
LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr, tree lhs,
tree rhs, HIR::Expr &lhs_expr,
tl::optional<std::reference_wrapper<HIR::Expr>> rhs_expr);
tl::optional<std::reference_wrapper<HIR::Expr>> rhs_expr,
HIR::PathIdentSegment specified_segment
= HIR::PathIdentSegment::create_error ());

tree compile_bool_literal (const HIR::LiteralExpr &expr,
const TyTy::BaseType *tyty);
Expand Down
6 changes: 6 additions & 0 deletions gcc/rust/hir/tree/rust-hir-expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1298,6 +1298,12 @@ OperatorExprMeta::OperatorExprMeta (HIR::ArrayIndexExpr &expr)
locus (expr.get_locus ())
{}

OperatorExprMeta::OperatorExprMeta (HIR::ComparisonExpr &expr)
: node_mappings (expr.get_mappings ()),
lvalue_mappings (expr.get_expr ().get_mappings ()),
locus (expr.get_locus ())
{}

AnonConst::AnonConst (NodeId id, std::unique_ptr<Expr> expr)
: id (id), expr (std::move (expr))
{
Expand Down
2 changes: 2 additions & 0 deletions gcc/rust/hir/tree/rust-hir-expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -2816,6 +2816,8 @@ class OperatorExprMeta

OperatorExprMeta (HIR::ArrayIndexExpr &expr);

OperatorExprMeta (HIR::ComparisonExpr &expr);

const Analysis::NodeMapping &get_mappings () const { return node_mappings; }

const Analysis::NodeMapping &get_lvalue_mappings () const
Expand Down
27 changes: 22 additions & 5 deletions gcc/rust/typecheck/rust-hir-type-check-expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,21 @@ TypeCheckExpr::visit (HIR::ComparisonExpr &expr)
auto lhs = TypeCheckExpr::Resolve (expr.get_lhs ());
auto rhs = TypeCheckExpr::Resolve (expr.get_rhs ());

auto borrwed_rhs
= new TyTy::ReferenceType (mappings.get_next_hir_id (),
TyTy::TyVar (rhs->get_ref ()), Mutability::Imm);
context->insert_implicit_type (borrwed_rhs->get_ref (), borrwed_rhs);

auto seg_name = LangItem::ComparisonToSegment (expr.get_expr_type ());
auto segment = HIR::PathIdentSegment (seg_name);
auto lang_item_type = LangItem::ComparisonToLangItem (expr.get_expr_type ());

bool operator_overloaded
= resolve_operator_overload (lang_item_type, expr, lhs, borrwed_rhs,
segment);
if (operator_overloaded)
return;

unify_site (expr.get_mappings ().get_hirid (),
TyTy::TyWithLocation (lhs, expr.get_lhs ().get_locus ()),
TyTy::TyWithLocation (rhs, expr.get_rhs ().get_locus ()),
Expand Down Expand Up @@ -1638,10 +1653,10 @@ TypeCheckExpr::visit (HIR::ClosureExpr &expr)
}

bool
TypeCheckExpr::resolve_operator_overload (LangItem::Kind lang_item_type,
HIR::OperatorExprMeta expr,
TyTy::BaseType *lhs,
TyTy::BaseType *rhs)
TypeCheckExpr::resolve_operator_overload (
LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr,
TyTy::BaseType *lhs, TyTy::BaseType *rhs,
HIR::PathIdentSegment specified_segment)
{
// look up lang item for arithmetic type
std::string associated_item_name = LangItem::ToString (lang_item_type);
Expand All @@ -1659,7 +1674,9 @@ TypeCheckExpr::resolve_operator_overload (LangItem::Kind lang_item_type,
current_context = context->peek_context ();
}

auto segment = HIR::PathIdentSegment (associated_item_name);
auto segment = specified_segment.is_error ()
? HIR::PathIdentSegment (associated_item_name)
: specified_segment;
auto candidates = MethodResolver::Probe (lhs, segment);

// remove any recursive candidates
Expand Down
4 changes: 3 additions & 1 deletion gcc/rust/typecheck/rust-hir-type-check-expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ class TypeCheckExpr : private TypeCheckBase, private HIR::HIRExpressionVisitor
protected:
bool resolve_operator_overload (LangItem::Kind lang_item_type,
HIR::OperatorExprMeta expr,
TyTy::BaseType *lhs, TyTy::BaseType *rhs);
TyTy::BaseType *lhs, TyTy::BaseType *rhs,
HIR::PathIdentSegment specified_segment
= HIR::PathIdentSegment::create_error ());

bool resolve_fn_trait_call (HIR::CallExpr &expr,
TyTy::BaseType *function_tyty,
Expand Down
44 changes: 44 additions & 0 deletions gcc/rust/util/rust-lang-item.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ const BiMap<std::string, LangItem::Kind> Rust::LangItem::lang_items = {{

{"into_iter", Kind::INTOITER_INTOITER},
{"next", Kind::ITERATOR_NEXT},

{"eq", Kind::EQ},
{"partial_ord", Kind::PARTIAL_ORD},
}};

tl::optional<LangItem::Kind>
Expand Down Expand Up @@ -145,6 +148,47 @@ LangItem::OperatorToLangItem (ArithmeticOrLogicalOperator op)
rust_unreachable ();
}

LangItem::Kind
LangItem::ComparisonToLangItem (ComparisonOperator op)
{
switch (op)
{
case ComparisonOperator::NOT_EQUAL:
case ComparisonOperator::EQUAL:
return LangItem::Kind::EQ;

case ComparisonOperator::GREATER_THAN:
case ComparisonOperator::LESS_THAN:
case ComparisonOperator::GREATER_OR_EQUAL:
case ComparisonOperator::LESS_OR_EQUAL:
return LangItem::Kind::PARTIAL_ORD;
}

rust_unreachable ();
}

std::string
LangItem::ComparisonToSegment (ComparisonOperator op)
{
switch (op)
{
case ComparisonOperator::NOT_EQUAL:
return "ne";
case ComparisonOperator::EQUAL:
return "eq";
case ComparisonOperator::GREATER_THAN:
return "gt";
case ComparisonOperator::LESS_THAN:
return "lt";
case ComparisonOperator::GREATER_OR_EQUAL:
return "ge";
case ComparisonOperator::LESS_OR_EQUAL:
return "le";
}

rust_unreachable ();
}

LangItem::Kind
LangItem::CompoundAssignmentOperatorToLangItem (ArithmeticOrLogicalOperator op)
{
Expand Down
5 changes: 5 additions & 0 deletions gcc/rust/util/rust-lang-item.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class LangItem

NEGATION,
NOT,
EQ,
PARTIAL_ORD,

ADD_ASSIGN,
SUB_ASSIGN,
Expand Down Expand Up @@ -136,6 +138,9 @@ class LangItem
static Kind
CompoundAssignmentOperatorToLangItem (ArithmeticOrLogicalOperator op);
static Kind NegationOperatorToLangItem (NegationOperator op);
static Kind ComparisonToLangItem (ComparisonOperator op);

static std::string ComparisonToSegment (ComparisonOperator op);
};

} // namespace Rust
Expand Down
8 changes: 4 additions & 4 deletions gcc/rust/util/rust-operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ enum class ComparisonOperator
{
EQUAL, // std::cmp::PartialEq::eq
NOT_EQUAL, // std::cmp::PartialEq::ne
GREATER_THAN, // std::cmp::PartialEq::gt
LESS_THAN, // std::cmp::PartialEq::lt
GREATER_OR_EQUAL, // std::cmp::PartialEq::ge
LESS_OR_EQUAL // std::cmp::PartialEq::le
GREATER_THAN, // std::cmp::PartialOrd::gt
LESS_THAN, // std::cmp::PartialOrd::lt
GREATER_OR_EQUAL, // std::cmp::PartialOrd::ge
LESS_OR_EQUAL // std::cmp::PartialOrd::le
};

enum class LazyBooleanOperator
Expand Down
78 changes: 78 additions & 0 deletions gcc/testsuite/rust/compile/cmp1.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// { dg-options "-w" }
// taken from https://github.com/rust-lang/rust/blob/e1884a8e3c3e813aada8254edfa120e85bf5ffca/library/core/src/cmp.rs#L98

#[lang = "sized"]
pub trait Sized {}

#[lang = "eq"]
#[stable(feature = "rust1", since = "1.0.0")]
#[doc(alias = "==")]
#[doc(alias = "!=")]
pub trait PartialEq<Rhs: ?Sized = Self> {
/// This method tests for `self` and `other` values to be equal, and is used
/// by `==`.
#[must_use]
#[stable(feature = "rust1", since = "1.0.0")]
fn eq(&self, other: &Rhs) -> bool;

/// This method tests for `!=`.
#[inline]
#[must_use]
#[stable(feature = "rust1", since = "1.0.0")]
fn ne(&self, other: &Rhs) -> bool {
!self.eq(other)
}
}

enum BookFormat {
Paperback,
Hardback,
Ebook,
}

impl PartialEq<BookFormat> for BookFormat {
fn eq(&self, other: &BookFormat) -> bool {
self == other
}
}

pub struct Book {
isbn: i32,
format: BookFormat,
}

// Implement <Book> == <BookFormat> comparisons
impl PartialEq<BookFormat> for Book {
fn eq(&self, other: &BookFormat) -> bool {
self.format == *other
}
}

// Implement <BookFormat> == <Book> comparisons
impl PartialEq<Book> for BookFormat {
fn eq(&self, other: &Book) -> bool {
*self == other.format
}
}

// Implement <Book> == <Book> comparisons
impl PartialEq<Book> for Book {
fn eq(&self, other: &Book) -> bool {
self.isbn == other.isbn
}
}

pub fn main() {
let b1 = Book {
isbn: 1,
format: BookFormat::Paperback,
};
let b2 = Book {
isbn: 2,
format: BookFormat::Paperback,
};

let _c1: bool = b1 == BookFormat::Paperback;
let _c2: bool = BookFormat::Paperback == b2;
let _c3: bool = b1 != b2;
}
1 change: 1 addition & 0 deletions gcc/testsuite/rust/compile/nr2/exclude
Original file line number Diff line number Diff line change
Expand Up @@ -204,4 +204,5 @@ issue-266.rs
additional-trait-bounds2.rs
auto_traits2.rs
auto_traits3.rs
cmp1.rs
# please don't delete the trailing newline

0 comments on commit b603ed0

Please sign in to comment.