diff --git a/gcc/rust/backend/rust-compile-expr.cc b/gcc/rust/backend/rust-compile-expr.cc index 7323413bfce2..46b9b0adccfd 100644 --- a/gcc/rust/backend/rust-compile-expr.cc +++ b/gcc/rust/backend/rust-compile-expr.cc @@ -279,6 +279,26 @@ CompileExpr::visit (HIR::ComparisonExpr &expr) auto rhs = CompileExpr::Compile (expr.get_rhs (), ctx); auto location = expr.get_locus (); + // this might be an operator overload situation lets check + TyTy::FnType *fntype; + bool is_op_overload = ctx->get_tyctx ()->lookup_operator_overload ( + expr.get_mappings ().get_hirid (), &fntype); + if (is_op_overload) + { + auto seg_name = LangItem::ComparisonToSegment (expr.get_expr_type ()); + auto segment = HIR::PathIdentSegment (seg_name); + auto lang_item_type + = LangItem::ComparisonToLangItem (expr.get_expr_type ()); + + rhs = address_expression (rhs, EXPR_LOCATION (rhs)); + + translated = resolve_operator_overload ( + lang_item_type, expr, lhs, rhs, expr.get_lhs (), + tl::optional> (expr.get_rhs ()), + segment); + return; + } + translated = Backend::comparison_expression (op, lhs, rhs, location); } @@ -1478,7 +1498,8 @@ CompileExpr::get_receiver_from_dyn (const TyTy::DynamicObjectType *dyn, tree CompileExpr::resolve_operator_overload ( LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr, tree lhs, tree rhs, - HIR::Expr &lhs_expr, tl::optional> rhs_expr) + HIR::Expr &lhs_expr, tl::optional> rhs_expr, + HIR::PathIdentSegment specified_segment) { TyTy::FnType *fntype; bool is_op_overload = ctx->get_tyctx ()->lookup_operator_overload ( @@ -1499,7 +1520,10 @@ CompileExpr::resolve_operator_overload ( } // lookup compiled functions since it may have already been compiled - HIR::PathIdentSegment segment_name (LangItem::ToString (lang_item_type)); + HIR::PathIdentSegment segment_name + = specified_segment.is_error () + ? HIR::PathIdentSegment (LangItem::ToString (lang_item_type)) + : specified_segment; tree fn_expr = resolve_method_address (fntype, receiver, expr.get_locus ()); // lookup the autoderef mappings diff --git a/gcc/rust/backend/rust-compile-expr.h b/gcc/rust/backend/rust-compile-expr.h index 9e7af42c64f2..45e0d3c350d0 100644 --- a/gcc/rust/backend/rust-compile-expr.h +++ b/gcc/rust/backend/rust-compile-expr.h @@ -99,7 +99,9 @@ class CompileExpr : private HIRCompileBase, protected HIR::HIRExpressionVisitor tree resolve_operator_overload ( LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr, tree lhs, tree rhs, HIR::Expr &lhs_expr, - tl::optional> rhs_expr); + tl::optional> rhs_expr, + HIR::PathIdentSegment specified_segment + = HIR::PathIdentSegment::create_error ()); tree compile_bool_literal (const HIR::LiteralExpr &expr, const TyTy::BaseType *tyty); diff --git a/gcc/rust/hir/tree/rust-hir-expr.cc b/gcc/rust/hir/tree/rust-hir-expr.cc index 4a902c655947..2ded789e60b1 100644 --- a/gcc/rust/hir/tree/rust-hir-expr.cc +++ b/gcc/rust/hir/tree/rust-hir-expr.cc @@ -1298,6 +1298,12 @@ OperatorExprMeta::OperatorExprMeta (HIR::ArrayIndexExpr &expr) locus (expr.get_locus ()) {} +OperatorExprMeta::OperatorExprMeta (HIR::ComparisonExpr &expr) + : node_mappings (expr.get_mappings ()), + lvalue_mappings (expr.get_expr ().get_mappings ()), + locus (expr.get_locus ()) +{} + AnonConst::AnonConst (NodeId id, std::unique_ptr expr) : id (id), expr (std::move (expr)) { diff --git a/gcc/rust/hir/tree/rust-hir-expr.h b/gcc/rust/hir/tree/rust-hir-expr.h index e956108dc53f..46039270f729 100644 --- a/gcc/rust/hir/tree/rust-hir-expr.h +++ b/gcc/rust/hir/tree/rust-hir-expr.h @@ -2816,6 +2816,8 @@ class OperatorExprMeta OperatorExprMeta (HIR::ArrayIndexExpr &expr); + OperatorExprMeta (HIR::ComparisonExpr &expr); + const Analysis::NodeMapping &get_mappings () const { return node_mappings; } const Analysis::NodeMapping &get_lvalue_mappings () const diff --git a/gcc/rust/typecheck/rust-hir-type-check-expr.cc b/gcc/rust/typecheck/rust-hir-type-check-expr.cc index 7daa27195db6..1cfd855de0c0 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-expr.cc +++ b/gcc/rust/typecheck/rust-hir-type-check-expr.cc @@ -344,6 +344,21 @@ TypeCheckExpr::visit (HIR::ComparisonExpr &expr) auto lhs = TypeCheckExpr::Resolve (expr.get_lhs ()); auto rhs = TypeCheckExpr::Resolve (expr.get_rhs ()); + auto borrwed_rhs + = new TyTy::ReferenceType (mappings.get_next_hir_id (), + TyTy::TyVar (rhs->get_ref ()), Mutability::Imm); + context->insert_implicit_type (borrwed_rhs->get_ref (), borrwed_rhs); + + auto seg_name = LangItem::ComparisonToSegment (expr.get_expr_type ()); + auto segment = HIR::PathIdentSegment (seg_name); + auto lang_item_type = LangItem::ComparisonToLangItem (expr.get_expr_type ()); + + bool operator_overloaded + = resolve_operator_overload (lang_item_type, expr, lhs, borrwed_rhs, + segment); + if (operator_overloaded) + return; + unify_site (expr.get_mappings ().get_hirid (), TyTy::TyWithLocation (lhs, expr.get_lhs ().get_locus ()), TyTy::TyWithLocation (rhs, expr.get_rhs ().get_locus ()), @@ -1638,10 +1653,10 @@ TypeCheckExpr::visit (HIR::ClosureExpr &expr) } bool -TypeCheckExpr::resolve_operator_overload (LangItem::Kind lang_item_type, - HIR::OperatorExprMeta expr, - TyTy::BaseType *lhs, - TyTy::BaseType *rhs) +TypeCheckExpr::resolve_operator_overload ( + LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr, + TyTy::BaseType *lhs, TyTy::BaseType *rhs, + HIR::PathIdentSegment specified_segment) { // look up lang item for arithmetic type std::string associated_item_name = LangItem::ToString (lang_item_type); @@ -1659,7 +1674,9 @@ TypeCheckExpr::resolve_operator_overload (LangItem::Kind lang_item_type, current_context = context->peek_context (); } - auto segment = HIR::PathIdentSegment (associated_item_name); + auto segment = specified_segment.is_error () + ? HIR::PathIdentSegment (associated_item_name) + : specified_segment; auto candidates = MethodResolver::Probe (lhs, segment); // remove any recursive candidates diff --git a/gcc/rust/typecheck/rust-hir-type-check-expr.h b/gcc/rust/typecheck/rust-hir-type-check-expr.h index 51fdd934da5d..3ceef7a521e7 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-expr.h +++ b/gcc/rust/typecheck/rust-hir-type-check-expr.h @@ -97,7 +97,9 @@ class TypeCheckExpr : private TypeCheckBase, private HIR::HIRExpressionVisitor protected: bool resolve_operator_overload (LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr, - TyTy::BaseType *lhs, TyTy::BaseType *rhs); + TyTy::BaseType *lhs, TyTy::BaseType *rhs, + HIR::PathIdentSegment specified_segment + = HIR::PathIdentSegment::create_error ()); bool resolve_fn_trait_call (HIR::CallExpr &expr, TyTy::BaseType *function_tyty, diff --git a/gcc/rust/util/rust-lang-item.cc b/gcc/rust/util/rust-lang-item.cc index 76fcd348e3f4..896e948039eb 100644 --- a/gcc/rust/util/rust-lang-item.cc +++ b/gcc/rust/util/rust-lang-item.cc @@ -98,6 +98,9 @@ const BiMap Rust::LangItem::lang_items = {{ {"into_iter", Kind::INTOITER_INTOITER}, {"next", Kind::ITERATOR_NEXT}, + + {"eq", Kind::EQ}, + {"partial_ord", Kind::PARTIAL_ORD}, }}; tl::optional @@ -145,6 +148,46 @@ LangItem::OperatorToLangItem (ArithmeticOrLogicalOperator op) rust_unreachable (); } +LangItem::Kind +LangItem::ComparisonToLangItem (ComparisonOperator op) +{ + switch (op) + { + case ComparisonOperator::NOT_EQUAL: + case ComparisonOperator::EQUAL: + return LangItem::Kind::EQ; + + case ComparisonOperator::GREATER_THAN: + case ComparisonOperator::LESS_THAN: + case ComparisonOperator::GREATER_OR_EQUAL: + case ComparisonOperator::LESS_OR_EQUAL: + return LangItem::Kind::PARTIAL_ORD; + } +} + +std::string +LangItem::ComparisonToSegment (ComparisonOperator op) +{ + switch (op) + { + case ComparisonOperator::NOT_EQUAL: + return "ne"; + case ComparisonOperator::EQUAL: + return "eq"; + case ComparisonOperator::GREATER_THAN: + return "gt"; + case ComparisonOperator::LESS_THAN: + return "lt"; + case ComparisonOperator::GREATER_OR_EQUAL: + return "ge"; + case ComparisonOperator::LESS_OR_EQUAL: + return "le"; + } + + rust_unreachable (); + return std::string (); +} + LangItem::Kind LangItem::CompoundAssignmentOperatorToLangItem (ArithmeticOrLogicalOperator op) { diff --git a/gcc/rust/util/rust-lang-item.h b/gcc/rust/util/rust-lang-item.h index 5d57533499ee..35ee5c2a2c3a 100644 --- a/gcc/rust/util/rust-lang-item.h +++ b/gcc/rust/util/rust-lang-item.h @@ -45,6 +45,8 @@ class LangItem NEGATION, NOT, + EQ, + PARTIAL_ORD, ADD_ASSIGN, SUB_ASSIGN, @@ -136,6 +138,9 @@ class LangItem static Kind CompoundAssignmentOperatorToLangItem (ArithmeticOrLogicalOperator op); static Kind NegationOperatorToLangItem (NegationOperator op); + static Kind ComparisonToLangItem (ComparisonOperator op); + + static std::string ComparisonToSegment (ComparisonOperator op); }; } // namespace Rust diff --git a/gcc/rust/util/rust-operators.h b/gcc/rust/util/rust-operators.h index e3d6205c780e..f460f4f798f2 100644 --- a/gcc/rust/util/rust-operators.h +++ b/gcc/rust/util/rust-operators.h @@ -43,10 +43,10 @@ enum class ComparisonOperator { EQUAL, // std::cmp::PartialEq::eq NOT_EQUAL, // std::cmp::PartialEq::ne - GREATER_THAN, // std::cmp::PartialEq::gt - LESS_THAN, // std::cmp::PartialEq::lt - GREATER_OR_EQUAL, // std::cmp::PartialEq::ge - LESS_OR_EQUAL // std::cmp::PartialEq::le + GREATER_THAN, // std::cmp::PartialOrd::gt + LESS_THAN, // std::cmp::PartialOrd::lt + GREATER_OR_EQUAL, // std::cmp::PartialOrd::ge + LESS_OR_EQUAL // std::cmp::PartialOrd::le }; enum class LazyBooleanOperator diff --git a/gcc/testsuite/rust/compile/cmp1.rs b/gcc/testsuite/rust/compile/cmp1.rs new file mode 100644 index 000000000000..4da5b1c01fc3 --- /dev/null +++ b/gcc/testsuite/rust/compile/cmp1.rs @@ -0,0 +1,78 @@ +// { dg-options "-w" } +// taken from https://github.com/rust-lang/rust/blob/e1884a8e3c3e813aada8254edfa120e85bf5ffca/library/core/src/cmp.rs#L98 + +#[lang = "sized"] +pub trait Sized {} + +#[lang = "eq"] +#[stable(feature = "rust1", since = "1.0.0")] +#[doc(alias = "==")] +#[doc(alias = "!=")] +pub trait PartialEq { + /// This method tests for `self` and `other` values to be equal, and is used + /// by `==`. + #[must_use] + #[stable(feature = "rust1", since = "1.0.0")] + fn eq(&self, other: &Rhs) -> bool; + + /// This method tests for `!=`. + #[inline] + #[must_use] + #[stable(feature = "rust1", since = "1.0.0")] + fn ne(&self, other: &Rhs) -> bool { + !self.eq(other) + } +} + +enum BookFormat { + Paperback, + Hardback, + Ebook, +} + +impl PartialEq for BookFormat { + fn eq(&self, other: &BookFormat) -> bool { + self == other + } +} + +pub struct Book { + isbn: i32, + format: BookFormat, +} + +// Implement == comparisons +impl PartialEq for Book { + fn eq(&self, other: &BookFormat) -> bool { + self.format == *other + } +} + +// Implement == comparisons +impl PartialEq for BookFormat { + fn eq(&self, other: &Book) -> bool { + *self == other.format + } +} + +// Implement == comparisons +impl PartialEq for Book { + fn eq(&self, other: &Book) -> bool { + self.isbn == other.isbn + } +} + +pub fn main() { + let b1 = Book { + isbn: 1, + format: BookFormat::Paperback, + }; + let b2 = Book { + isbn: 2, + format: BookFormat::Paperback, + }; + + let _c1: bool = b1 == BookFormat::Paperback; + let _c2: bool = BookFormat::Paperback == b2; + let _c3: bool = b1 != b2; +}