From a50ae3127127ce13c4efee9521766135ba88242d Mon Sep 17 00:00:00 2001 From: kahmed10 <15948690+kahmed10@users.noreply.github.com> Date: Fri, 22 Sep 2023 11:18:06 -0500 Subject: [PATCH] rel 5.7 - Optimize broadcast+transpose for scalars (#2205) --- src/optimize_module.cpp | 10 +++-- src/propagate_constant.cpp | 8 ++-- src/simplify_reshapes.cpp | 27 +++++++++++++- test/optimize_module_test.cpp | 65 +++++++++++++++++++++++++++++++++ test/simplify_reshapes_test.cpp | 51 +++++++++++++++++++++++++- 5 files changed, 152 insertions(+), 9 deletions(-) create mode 100644 test/optimize_module_test.cpp diff --git a/src/optimize_module.cpp b/src/optimize_module.cpp index 26d3757b4b6..94fa77d01ce 100644 --- a/src/optimize_module.cpp +++ b/src/optimize_module.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -36,8 +36,12 @@ void optimize_module::apply(module_pass_manager& mpm) const { for(int i = 0; i < 2; i++) { - mpm.run_pass(simplify_reshapes{}); - mpm.run_pass(simplify_algebra{}); + // loop to further optimize after initial transformations + for(int j = 0; j < 2; j++) + { + mpm.run_pass(simplify_reshapes{}); + mpm.run_pass(simplify_algebra{}); + } mpm.run_pass(eliminate_common_subexpression{}); mpm.run_pass(dead_code_elimination{}); mpm.run_pass(propagate_constant{}); diff --git a/src/propagate_constant.cpp b/src/propagate_constant.cpp index cf9044152f7..4b92a4854c8 100644 --- a/src/propagate_constant.cpp +++ b/src/propagate_constant.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -35,10 +35,10 @@ inline namespace MIGRAPHX_INLINE_NS { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT) -bool skip_propogate(instruction_ref ins) +bool skip_propagate(instruction_ref ins) { if(ins->name() == "contiguous") - return skip_propogate(ins->inputs().front()); + return skip_propagate(ins->inputs().front()); auto&& s = ins->get_shape(); if(s.broadcasted() and not s.scalar()) return true; @@ -47,7 +47,7 @@ bool skip_propogate(instruction_ref ins) return false; } -bool is_const_ins(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); } +bool is_const_ins(instruction_ref ins) { return ins->can_eval() and not skip_propagate(ins); } void propagate_constant::apply(module& m) const { diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index fa5a7248932..64b17a96aa5 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -627,6 +627,30 @@ struct find_transpose_contiguous_reshaper_unary } }; +struct find_broadcast_transpose +{ + auto matcher() const + { + return match::name("transpose")( + match::arg(0)(match::name("multibroadcast").bind("bcast_ins"))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto ins_lens = ins->get_shape().lens(); + auto bcast_ins = r.instructions["bcast_ins"]; + auto input = bcast_ins->inputs().front(); + // for now, focusing on scalar transformation + if(not input->get_shape().scalar()) + return; + + auto new_mbcast = m.insert_instruction( + bcast_ins, make_op("multibroadcast", {{"out_lens", ins_lens}}), input); + m.replace_instruction(ins, new_mbcast); + } +}; + struct find_slice_transpose { auto matcher() const @@ -799,6 +823,7 @@ void simplify_reshapes::apply(module& m) const find_nested_slice{}, find_nested_concat{}, find_transpose_slice{}, + find_broadcast_transpose{}, find_slice_transpose{}, find_transpose_contiguous_reshaper_unary{}); dead_code_elimination{}.apply(m); diff --git a/test/optimize_module_test.cpp b/test/optimize_module_test.cpp new file mode 100644 index 00000000000..d54219f1d93 --- /dev/null +++ b/test/optimize_module_test.cpp @@ -0,0 +1,65 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include +#include + +void run_pass(migraphx::module& m) { migraphx::run_passes(m, {migraphx::optimize_module{}}); } + +TEST_CASE(broadcast_transpose_inner_broadcast) +{ + // first optimizes broadcast+transpose to just a broadcast, + // then finds inner broadcast to become mul+broadcast + migraphx::module m1; + { + auto l1 = m1.add_parameter("x", {migraphx::shape::float_type, {1}, {0}}); + auto l2 = m1.add_parameter("y", {migraphx::shape::float_type, {1}, {0}}); + auto mb1 = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3}}}), l1); + auto mb2 = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 2}}}), l2); + auto t1 = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), mb1); + auto mul = m1.add_instruction(migraphx::make_op("mul"), mb2, t1); + m1.add_return({mul}); + } + run_pass(m1); + migraphx::module m2; + { + auto l1 = m2.add_parameter("x", {migraphx::shape::float_type, {1}, {0}}); + auto l2 = m2.add_parameter("y", {migraphx::shape::float_type, {1}, {0}}); + auto mul = m2.add_instruction(migraphx::make_op("mul"), l2, l1); + auto mb = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 2}}}), mul); + m2.add_return({mb}); + } + EXPECT(m1 == m2); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 0b16b46616a..a1679712d1d 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -68,6 +68,55 @@ migraphx::module make_concat_multibroadcast(const std::vector& in_lens, return m; } +TEST_CASE(broadcast_transpose_scalar) +{ + migraphx::module m1; + { + auto l = m1.add_parameter("x", {migraphx::shape::float_type, {1}, {0}}); + auto mb = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3}}}), l); + auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), mb); + m1.add_return({t1}); + } + run_pass(m1); + migraphx::module m2; + { + auto l = m2.add_parameter("x", {migraphx::shape::float_type, {1}, {0}}); + auto mb = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2}}}), l); + m2.add_return({mb}); + } + + EXPECT(m1 == m2); +} + +TEST_CASE(broadcast_transpose_scalar_multi_use) +{ + // multibroadcast used more than once + migraphx::module m1; + { + auto l = m1.add_parameter("x", {migraphx::shape::float_type, {1}, {0}}); + auto mb = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3}}}), l); + auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), mb); + auto id = m1.add_instruction(migraphx::make_op("identity"), mb); + m1.add_return({t1, id}); + } + run_pass(m1); + migraphx::module m2; + { + auto l = m2.add_parameter("x", {migraphx::shape::float_type, {1}, {0}}); + auto mb = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2}}}), l); + auto mb2 = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3}}}), l); + auto id = m2.add_instruction(migraphx::make_op("identity"), mb2); + m2.add_return({mb, id}); + } + + EXPECT(m1 == m2); +} + TEST_CASE(double_contig) { migraphx::program p;