diff --git a/libredex/CallGraph.cpp b/libredex/CallGraph.cpp index 864e077e25..410a3af062 100644 --- a/libredex/CallGraph.cpp +++ b/libredex/CallGraph.cpp @@ -136,11 +136,6 @@ RootAndDynamic MultipleCalleeBaseStrategy::get_roots() const { roots.insert(method); return; } - // For methods marked with DoNotInline, we also add to dynamic methods set - // to avoid propagating return value. - if (method->rstate.dont_inline()) { - dynamic_methods.emplace(method); - } if (!root(method) && !method::is_argless_init(method) && !(method->is_virtual() && is_interface(type_class(method->get_class())) && @@ -307,7 +302,7 @@ MultipleCalleeStrategy::MultipleCalleeStrategy( if (callee == nullptr) { return; } - if (!callee->is_virtual()) { + if (!callee->is_virtual() || insn->opcode() == OPCODE_INVOKE_SUPER) { return; } if (!concurrent_callees.insert(callee)) { @@ -348,7 +343,8 @@ CallSites MultipleCalleeStrategy::get_callsites(const DexMethod* method) const { if (callee == nullptr) { return editable_cfg_adapter::LOOP_CONTINUE; } - if (is_definitely_virtual(callee)) { + if (is_definitely_virtual(callee) && + insn->opcode() != OPCODE_INVOKE_SUPER) { // For true virtual callees, add the callee itself and all of its // overrides if they are not in big overrides. if (m_big_override.count_unsafe(callee)) { @@ -357,12 +353,10 @@ CallSites MultipleCalleeStrategy::get_callsites(const DexMethod* method) const { if (callee->get_code()) { callsites.emplace_back(callee, insn); } - if (insn->opcode() != OPCODE_INVOKE_SUPER) { - const auto& overriding_methods = - get_ordered_overriding_methods_with_code(callee); - for (auto overriding_method : overriding_methods) { - callsites.emplace_back(overriding_method, insn); - } + const auto& overriding_methods = + get_ordered_overriding_methods_with_code(callee); + for (auto overriding_method : overriding_methods) { + callsites.emplace_back(overriding_method, insn); } } else if (callee->is_concrete()) { callsites.emplace_back(callee, insn); @@ -559,8 +553,21 @@ const MethodSet& resolve_callees_in_graph(const Graph& graph, return no_methods; } -bool method_is_dynamic(const Graph& graph, const DexMethod* method) { - return graph.get_dynamic_methods().count(method); +bool invoke_is_dynamic(const Graph& graph, const IRInstruction* insn) { + auto* callee = resolve_invoke_method(insn); + if (callee == nullptr) { + return true; + } + // For methods marked with DoNotInline, we also treat them like dynamic + // methods to avoid propagating return value. + if (callee->rstate.dont_inline()) { + return true; + } + if (insn->opcode() != OPCODE_INVOKE_VIRTUAL && + insn->opcode() != OPCODE_INVOKE_INTERFACE) { + return false; + } + return graph.get_dynamic_methods().count(callee); } CallgraphStats get_num_nodes_edges(const Graph& graph) { diff --git a/libredex/CallGraph.h b/libredex/CallGraph.h index d06d469589..a16838290b 100644 --- a/libredex/CallGraph.h +++ b/libredex/CallGraph.h @@ -326,7 +326,7 @@ const MethodSet& resolve_callees_in_graph(const Graph& graph, const MethodVector& get_callee_to_callers(const Graph& graph, const DexMethod* callee); -bool method_is_dynamic(const Graph& graph, const DexMethod* method); +bool invoke_is_dynamic(const Graph& graph, const IRInstruction* insn); struct CallgraphStats { uint32_t num_nodes; diff --git a/service/constant-propagation/ConstantPropagationWholeProgramState.cpp b/service/constant-propagation/ConstantPropagationWholeProgramState.cpp index 7ab869e99c..1332989bfb 100644 --- a/service/constant-propagation/ConstantPropagationWholeProgramState.cpp +++ b/service/constant-propagation/ConstantPropagationWholeProgramState.cpp @@ -358,11 +358,7 @@ bool WholeProgramAwareAnalyzer::analyze_invoke( return false; } if (whole_program_state->has_call_graph()) { - auto method = resolve_invoke_method(insn); - if (method == nullptr) { - return false; - } - if (whole_program_state->method_is_dynamic(method)) { + if (whole_program_state->invoke_is_dynamic(insn)) { return false; } auto value = whole_program_state->get_return_value_from_cg(insn); diff --git a/service/constant-propagation/ConstantPropagationWholeProgramState.h b/service/constant-propagation/ConstantPropagationWholeProgramState.h index 496f39d7d4..3ba1c64b18 100644 --- a/service/constant-propagation/ConstantPropagationWholeProgramState.h +++ b/service/constant-propagation/ConstantPropagationWholeProgramState.h @@ -122,8 +122,8 @@ class WholeProgramState { const call_graph::Graph* call_graph() const { return m_call_graph.get(); } - bool method_is_dynamic(const DexMethod* method) const { - return call_graph::method_is_dynamic(*m_call_graph, method); + bool invoke_is_dynamic(const IRInstruction* insn) const { + return call_graph::invoke_is_dynamic(*m_call_graph, insn); } private: @@ -187,8 +187,8 @@ class WholeProgramStateAccessor { bool has_call_graph() const { return m_wps.has_call_graph(); } - bool method_is_dynamic(const DexMethod* method) const { - return m_wps.method_is_dynamic(method); + bool invoke_is_dynamic(const IRInstruction* insn) const { + return m_wps.invoke_is_dynamic(insn); } ConstantValue get_field_value(const DexField* field) const { diff --git a/service/type-analysis/TypeAnalysisRuntimeAssert.cpp b/service/type-analysis/TypeAnalysisRuntimeAssert.cpp index ac21e24b81..cb2d7314dd 100644 --- a/service/type-analysis/TypeAnalysisRuntimeAssert.cpp +++ b/service/type-analysis/TypeAnalysisRuntimeAssert.cpp @@ -291,11 +291,10 @@ bool RuntimeAssertTransform::insert_return_value_assert( DexMethod* callee = nullptr; DexTypeDomain domain = DexTypeDomain::top(); if (wps.has_call_graph()) { - callee = resolve_invoke_method(insn); - if (callee == nullptr || wps.method_is_dynamic(callee)) { - domain = DexTypeDomain::top(); + if (wps.invoke_is_dynamic(insn)) { return false; } + callee = resolve_invoke_method(insn); domain = wps.get_return_type_from_cg(insn); } else { callee = resolve_method(insn->get_method(), opcode_to_search(insn)); diff --git a/service/type-analysis/WholeProgramState.cpp b/service/type-analysis/WholeProgramState.cpp index a7cd961500..c1b7b3d9e5 100644 --- a/service/type-analysis/WholeProgramState.cpp +++ b/service/type-analysis/WholeProgramState.cpp @@ -482,8 +482,7 @@ bool WholeProgramAwareAnalyzer::analyze_invoke( } if (whole_program_state->has_call_graph()) { - auto method = resolve_invoke_method(insn); - if (method == nullptr || whole_program_state->method_is_dynamic(method)) { + if (whole_program_state->invoke_is_dynamic(insn)) { env->set(RESULT_REGISTER, DexTypeDomain::top()); return false; } diff --git a/service/type-analysis/WholeProgramState.h b/service/type-analysis/WholeProgramState.h index fc098c39b9..1ca554f306 100644 --- a/service/type-analysis/WholeProgramState.h +++ b/service/type-analysis/WholeProgramState.h @@ -154,8 +154,8 @@ class WholeProgramState { return ret; } - bool method_is_dynamic(const DexMethod* method) const { - return call_graph::method_is_dynamic(*m_call_graph, method); + bool invoke_is_dynamic(const IRInstruction* insn) const { + return call_graph::invoke_is_dynamic(*m_call_graph, insn); } // For debugging diff --git a/test/integ/CallGraphTest.cpp b/test/integ/CallGraphTest.cpp index 45b98ac2c8..eecf2b7ba8 100644 --- a/test/integ/CallGraphTest.cpp +++ b/test/integ/CallGraphTest.cpp @@ -39,6 +39,9 @@ struct CallGraphTest : public RedexIntegrationTest { DexMethod* pure_ref_intf_return; DexMethod* pure_ref_3_return; DexMethod* pure_ref_3_init; + DexMethod* more_than_5_class_extends_1_init; + DexMethod* more_than_5_class_extends_1_return_super_num; + DexMethod* more_than_5_class_return_num; Scope scope; std::unique_ptr method_override_graph; @@ -140,6 +143,22 @@ struct CallGraphTest : public RedexIntegrationTest { pure_ref_3_init = DexMethod::get_method( "Lcom/facebook/redextest/PureRefImpl3;.:()V") ->as_def(); + + more_than_5_class_extends_1_init = DexMethod::get_method( + "Lcom/facebook/redextest/" + "MoreThan5ClassExtends1;.:()V") + ->as_def(); + + more_than_5_class_extends_1_return_super_num = + DexMethod::get_method( + "Lcom/facebook/redextest/" + "MoreThan5ClassExtends1;.returnSuperNum:()I") + ->as_def(); + + more_than_5_class_return_num = + DexMethod::get_method( + "Lcom/facebook/redextest/MoreThan5Class;.returnNum:()I") + ->as_def(); } std::vector get_callees(const call_graph::Graph& graph, @@ -211,15 +230,12 @@ TEST_F(CallGraphTest, test_multiple_callee_graph_entry) { TEST_F(CallGraphTest, test_multiple_callee_graph_clinit) { auto clinit_callees = get_callees(*multiple_graph, clinit); EXPECT_THAT(clinit_callees, - ::testing::UnorderedElementsAre(calls_returns_int, - base_foo, - extended_init, - less_impl3_init, - more_impl1_init, - less_impl1_return, - less_impl2_return, - less_impl3_return, - less_impl4_return)); + ::testing::UnorderedElementsAre( + calls_returns_int, base_foo, extended_init, less_impl3_init, + more_impl1_init, less_impl1_return, less_impl2_return, + less_impl3_return, less_impl4_return, + more_than_5_class_extends_1_init, + more_than_5_class_extends_1_return_super_num)); } TEST_F(CallGraphTest, test_multiple_callee_graph_return4) { @@ -244,3 +260,10 @@ TEST_F(CallGraphTest, test_multiple_callee_graph_extended_returns_int) { EXPECT_THAT(extendedextended_returns_int_callees, ::testing::UnorderedElementsAre(extended_returns_int)); } + +TEST_F(CallGraphTest, test_multiple_callee_graph_invoke_super) { + auto callees = get_callees(*multiple_graph, + more_than_5_class_extends_1_return_super_num); + EXPECT_THAT(callees, + ::testing::UnorderedElementsAre(more_than_5_class_return_num)); +} diff --git a/test/integ/CallGraphTest.java b/test/integ/CallGraphTest.java index 28cb3627a4..9dd50e6f18 100644 --- a/test/integ/CallGraphTest.java +++ b/test/integ/CallGraphTest.java @@ -17,6 +17,8 @@ public class CallGraphTest { int get1 = moreThan5.returnNum(); LessThan5 lessThan5 = new LessThan5Impl3(); int get3 = lessThan5.returnNum(); + MoreThan5ClassExtends1 moreThan5ClassExtends1 = new MoreThan5ClassExtends1(); + int get4 = moreThan5ClassExtends1.returnSuperNum(); } static int callsReturnsInt(Base b) { @@ -109,3 +111,49 @@ abstract class PureRefImpl2 extends PureRefImpl1 {} class PureRefImpl3 extends PureRefImpl2 { public int returnNum() { return 5; } } + +class MoreThan5Class { + public int returnNum() { + return 0; + } +} + +class MoreThan5ClassExtends1 extends MoreThan5Class { + public int returnNum() { + return 1; + } + + public int returnSuperNum() { + return super.returnNum(); + } +} + +class MoreThan5ClassExtends2 extends MoreThan5Class { + public int returnNum() { + return 2; + } +} + +class MoreThan5ClassExtends3 extends MoreThan5Class { + public int returnNum() { + return 3; + } +} + +class MoreThan5ClassExtends4 extends MoreThan5Class { + public int returnNum() { + return 4; + } +} + +class MoreThan5ClassExtends5 extends MoreThan5Class { + public int returnNum() { + return 5; + } +} + +class MoreThan5ClassExtends6 extends MoreThan5Class { + public int returnNum() { + return 6; + } +}