diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp
index 326b5ef7752..ae6c68f445a 100644
--- a/src/simplify_algebra.cpp
+++ b/src/simplify_algebra.cpp
@@ -266,6 +266,99 @@ struct find_mul_dot
     }
 };
 
+/*
+Moves the slice on the output of the Dot operation to slices on the inputs of the Dot operation to
+avoid computing redundant values.
+e.g. slice(gemm(a, b)) --> gemm(slice(a), slice(b))
+
+Slice axes below the last two (batch axes) are applied to both inputs, the second-to-last axis is
+applied to the first input only, and the last axis is applied to the second input only.
+*/
+struct find_dot_slice
+{
+    auto matcher() const
+    {
+        // Match a slice whose input is a dot/quant_dot restricted by used_once().
+        return match::name("slice")(
+            match::args(match::name("dot", "quant_dot")(match::used_once()).bind("dot_ins")));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto slice_ins = r.result;
+        auto dot_ins   = r.instructions["dot_ins"];
+        auto slice_op  = slice_ins->normalized_operator().to_value();
+        auto axes      = slice_op["axes"].to_vector<int64_t>();
+        auto starts    = slice_op["starts"].to_vector<int64_t>();
+        auto ends      = slice_op["ends"].to_vector<int64_t>();
+        assert(starts.size() == ends.size() and starts.size() == axes.size());
+        auto has_neg_vals = [](const auto& vec) {
+            return std::any_of(vec.begin(), vec.end(), [](auto i) { return i < 0; });
+        };
+        // The axis classification below compares against non-negative dimension indices.
+        if(has_neg_vals(starts) or has_neg_vals(ends) or has_neg_vals(axes))
+        {
+            MIGRAPHX_THROW("FIND_DOT_SLICE: slice is not normalized.");
+        }
+        auto dot_inputs = dot_ins->inputs();
+        // Signed so that it compares against the int64_t axes without a sign conversion.
+        const auto num_batch_dims = static_cast<int64_t>(dot_ins->get_shape().lens().size()) - 2;
+        std::vector<int64_t> slice_axes_1, starts_1, ends_1; // NOLINT
+        std::vector<int64_t> slice_axes_2, starts_2, ends_2; // NOLINT
+        for(auto i : range(axes.size()))
+        {
+            if(axes[i] < num_batch_dims)
+            {
+                // Batch axis: slice both inputs identically.
+                slice_axes_1.push_back(axes[i]);
+                starts_1.push_back(starts[i]);
+                ends_1.push_back(ends[i]);
+                slice_axes_2.push_back(axes[i]);
+                starts_2.push_back(starts[i]);
+                ends_2.push_back(ends[i]);
+            }
+            else if(axes[i] == num_batch_dims)
+            {
+                // Second-to-last output axis: slice the first input only.
+                slice_axes_1.push_back(axes[i]);
+                starts_1.push_back(starts[i]);
+                ends_1.push_back(ends[i]);
+            }
+            else if(axes[i] == num_batch_dims + 1)
+            {
+                // Last output axis: slice the second input only.
+                slice_axes_2.push_back(axes[i]);
+                starts_2.push_back(starts[i]);
+                ends_2.push_back(ends[i]);
+            }
+            else
+            {
+                MIGRAPHX_THROW("FIND_DOT_SLICE: invalid case");
+            }
+        }
+        auto slice_1 = dot_inputs.at(0);
+        if(not slice_axes_1.empty())
+        {
+            slice_1 = m.insert_instruction(
+                slice_ins,
+                migraphx::make_op("slice",
+                                  {{"axes", slice_axes_1}, {"starts", starts_1}, {"ends", ends_1}}),
+                dot_inputs.at(0));
+        }
+        auto slice_2 = dot_inputs.at(1);
+        if(not slice_axes_2.empty())
+        {
+            slice_2 = m.insert_instruction(
+                slice_ins,
+                migraphx::make_op("slice",
+                                  {{"axes", slice_axes_2}, {"starts", starts_2}, {"ends", ends_2}}),
+                dot_inputs.at(1));
+        }
+        // Reuse the matched dot's operator on the sliced inputs in place of the slice.
+        m.replace_instruction(slice_ins, dot_ins->get_operator(), {slice_1, slice_2});
+    }
+};
+
 struct find_dot_mul
 {
     auto matcher() const
@@ -1896,6 +1989,7 @@ void simplify_algebra::apply(module& m) const
                             find_mul_conv{},
                             find_mul_slice_conv{},
                             find_mul_dot{},
+                            find_dot_slice{},
                             find_dot_mul{},
                             find_mul_add{},
                             find_unit_ops{},
diff --git a/test/simplify_algebra_test.cpp b/test/simplify_algebra_test.cpp
index c4f7bd036b5..350952adb14 100644
--- a/test/simplify_algebra_test.cpp
+++ b/test/simplify_algebra_test.cpp
@@ -4081,6 +4081,179 @@ TEST_CASE(dot_mul_b_non_const)
     EXPECT(m1.sort() == m2.sort());
 }
 
+// A slice on the last axis of the dot output is expected to move to input b only.
+TEST_CASE(dot_slice_b)
+{
+    migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
+    migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
+    migraphx::module m1;
+    {
+        auto a     = m1.add_parameter("a", as);
+        auto b     = m1.add_parameter("b", bs);
+        auto dot   = m1.add_instruction(migraphx::make_op("dot"), a, b);
+        auto slice = m1.add_instruction(
+            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {64}}, {"ends", {128}}}), dot);
+        m1.add_return({slice});
+    };
+    run_pass(m1);
+    migraphx::module m2;
+    {
+        auto a       = m2.add_parameter("a", as);
+        auto b       = m2.add_parameter("b", bs);
+        auto b_slice = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {64}}, {"ends", {128}}}), b);
+        auto dot = m2.add_instruction(migraphx::make_op("dot"), a, b_slice);
+        m2.add_return({dot});
+    };
+    EXPECT(m1.sort() == m2.sort());
+}
+
+// A slice on the second-to-last axis of the dot output is expected to move to input a only.
+TEST_CASE(dot_slice_a)
+{
+    migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
+    migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
+    migraphx::module m1;
+    {
+        auto a     = m1.add_parameter("a", as);
+        auto b     = m1.add_parameter("b", bs);
+        auto dot   = m1.add_instruction(migraphx::make_op("dot"), a, b);
+        auto slice = m1.add_instruction(
+            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {64}}, {"ends", {128}}}), dot);
+        m1.add_return({slice});
+    };
+    run_pass(m1);
+    migraphx::module m2;
+    {
+        auto a       = m2.add_parameter("a", as);
+        auto b       = m2.add_parameter("b", bs);
+        auto a_slice = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {64}}, {"ends", {128}}}), a);
+        auto dot = m2.add_instruction(migraphx::make_op("dot"), a_slice, b);
+        m2.add_return({dot});
+    };
+    EXPECT(m1.sort() == m2.sort());
+}
+
+// A slice on both of the last two axes is expected to split into one slice per input.
+TEST_CASE(dot_slice_ab)
+{
+    migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
+    migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
+    migraphx::module m1;
+    {
+        auto a     = m1.add_parameter("a", as);
+        auto b     = m1.add_parameter("b", bs);
+        auto dot   = m1.add_instruction(migraphx::make_op("dot"), a, b);
+        auto slice = m1.add_instruction(
+            migraphx::make_op("slice",
+                              {{"axes", {1, 2}}, {"starts", {64, 64}}, {"ends", {128, 128}}}),
+            dot);
+        m1.add_return({slice});
+    };
+    run_pass(m1);
+    migraphx::module m2;
+    {
+        auto a       = m2.add_parameter("a", as);
+        auto b       = m2.add_parameter("b", bs);
+        auto a_slice = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {64}}, {"ends", {128}}}), a);
+        auto b_slice = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {64}}, {"ends", {128}}}), b);
+        auto dot = m2.add_instruction(migraphx::make_op("dot"), a_slice, b_slice);
+        m2.add_return({dot});
+    };
+    EXPECT(m1.sort() == m2.sort());
+}
+
+// A slice on the batch axis (axis 0) is expected to be applied to both inputs.
+TEST_CASE(dot_slice_batch_dims)
+{
+    migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
+    migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
+    migraphx::module m1;
+    {
+        auto a     = m1.add_parameter("a", as);
+        auto b     = m1.add_parameter("b", bs);
+        auto dot   = m1.add_instruction(migraphx::make_op("dot"), a, b);
+        auto slice = m1.add_instruction(
+            migraphx::make_op(
+                "slice", {{"axes", {0, 1, 2}}, {"starts", {0, 64, 64}}, {"ends", {1, 128, 128}}}),
+            dot);
+        m1.add_return({slice});
+    };
+    run_pass(m1);
+    migraphx::module m2;
+    {
+        auto a       = m2.add_parameter("a", as);
+        auto b       = m2.add_parameter("b", bs);
+        auto a_slice = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 64}}, {"ends", {1, 128}}}),
+            a);
+        auto b_slice = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {0, 2}}, {"starts", {0, 64}}, {"ends", {1, 128}}}),
+            b);
+        auto dot = m2.add_instruction(migraphx::make_op("dot"), a_slice, b_slice);
+        m2.add_return({dot});
+    };
+    EXPECT(m1.sort() == m2.sort());
+}
+
+// The dot output feeds two slices, so the module is expected to be left unchanged.
+TEST_CASE(dot_slice_not_applicable_1)
+{
+    migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
+    migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
+    migraphx::module m1;
+    {
+        auto a      = m1.add_parameter("a", as);
+        auto b      = m1.add_parameter("b", bs);
+        auto dot    = m1.add_instruction(migraphx::make_op("dot"), a, b);
+        auto slice1 = m1.add_instruction(
+            migraphx::make_op("slice",
+                              {{"axes", {1, 2}}, {"starts", {64, 64}}, {"ends", {128, 128}}}),
+            dot);
+        auto slice2 = m1.add_instruction(
+            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), dot);
+
+        m1.add_return({slice1, slice2});
+    };
+    migraphx::module m2 = m1;
+    run_pass(m1);
+    EXPECT(m1.sort() == m2.sort());
+}
+
+// Negative slice axes: despite the test name, the expected module has the rewrite applied.
+TEST_CASE(dot_slice_not_applicable_2)
+{
+    migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
+    migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
+    migraphx::module m1;
+    {
+        auto a     = m1.add_parameter("a", as);
+        auto b     = m1.add_parameter("b", bs);
+        auto dot   = m1.add_instruction(migraphx::make_op("dot"), a, b);
+        auto slice = m1.add_instruction(
+            migraphx::make_op("slice",
+                              {{"axes", {-2, -1}}, {"starts", {64, 64}}, {"ends", {128, 128}}}),
+            dot);
+        m1.add_return({slice});
+    };
+    run_pass(m1);
+    migraphx::module m2;
+    {
+        auto a       = m2.add_parameter("a", as);
+        auto b       = m2.add_parameter("b", bs);
+        auto a_slice = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {64}}, {"ends", {128}}}), a);
+        auto b_slice = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {64}}, {"ends", {128}}}), b);
+        auto dot = m2.add_instruction(migraphx::make_op("dot"), a_slice, b_slice);
+        m2.add_return({dot});
+    };
+    EXPECT(m1.sort() == m2.sort());
+}
+
 TEST_CASE(conv_concat)
 {
     migraphx::shape xs{migraphx::shape::float_type, {1, 8, 4, 4}};