Skip to content

Commit

Permalink
Find dot slice (#3268)
Browse files Browse the repository at this point in the history
  • Loading branch information
umangyadav authored and TedThemistokleous committed Aug 21, 2024
1 parent 3831ec4 commit 262f79f
Show file tree
Hide file tree
Showing 2 changed files with 251 additions and 0 deletions.
84 changes: 84 additions & 0 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,89 @@ struct find_mul_dot
}
};

/*
Moves the slice on the output of the Dot operation to slices on the inputs of the Dot operation to
avoid computing redundant values.
e.g. slice(gemm(a, b)) --> gemm(slice(a), slice(b))
*/
struct find_dot_slice
{
auto matcher() const
{
return match::name("slice")(
match::args(match::name("dot", "quant_dot")(match::used_once()).bind("dot_ins")));
}

void apply(module& m, const match::matcher_result& r) const
{
auto slice_ins = r.result;
auto dot_ins = r.instructions["dot_ins"];
auto slice_op = slice_ins->normalized_operator().to_value();
auto axes = slice_op["axes"].to_vector<int64_t>();
auto starts = slice_op["starts"].to_vector<int64_t>();
auto ends = slice_op["ends"].to_vector<int64_t>();
assert(starts.size() == ends.size() and starts.size() == axes.size());
auto has_neg_vals = [](auto vec) {
return std::any_of(vec.begin(), vec.end(), [](auto i) { return i < 0; });
};
if(has_neg_vals(starts) or has_neg_vals(ends) or has_neg_vals(axes))
{
MIGRAPHX_THROW("FIND_DOT_SLICE: slice is not normalized.");
}
auto dot_inputs = dot_ins->inputs();
auto num_batch_dims = dot_ins->get_shape().lens().size() - 2;
std::vector<int64_t> slice_axes_1, starts_1, ends_1; // NOLINT
std::vector<int64_t> slice_axes_2, starts_2, ends_2; // NOLINT
for(auto i : range(axes.size()))
{
if(axes[i] < num_batch_dims)
{
slice_axes_1.push_back(axes[i]);
starts_1.push_back(starts[i]);
ends_1.push_back(ends[i]);
slice_axes_2.push_back(axes[i]);
starts_2.push_back(starts[i]);
ends_2.push_back(ends[i]);
}
else if(axes[i] == num_batch_dims)
{
slice_axes_1.push_back(axes[i]);
starts_1.push_back(starts[i]);
ends_1.push_back(ends[i]);
}
else if(axes[i] == num_batch_dims + 1)
{
slice_axes_2.push_back(axes[i]);
starts_2.push_back(starts[i]);
ends_2.push_back(ends[i]);
}
else
{
MIGRAPHX_THROW("FIND_DOT_SLICE: invalid case");
}
}
auto slice_1 = dot_inputs.at(0);
if(not slice_axes_1.empty())
{
slice_1 = m.insert_instruction(
slice_ins,
migraphx::make_op("slice",
{{"axes", slice_axes_1}, {"starts", starts_1}, {"ends", ends_1}}),
dot_inputs.at(0));
}
auto slice_2 = dot_inputs.at(1);
if(not slice_axes_2.empty())
{
slice_2 = m.insert_instruction(
slice_ins,
migraphx::make_op("slice",
{{"axes", slice_axes_2}, {"starts", starts_2}, {"ends", ends_2}}),
dot_inputs.at(1));
}
m.replace_instruction(slice_ins, dot_ins->get_operator(), {slice_1, slice_2});
}
};

struct find_dot_mul
{
auto matcher() const
Expand Down Expand Up @@ -1896,6 +1979,7 @@ void simplify_algebra::apply(module& m) const
find_mul_conv{},
find_mul_slice_conv{},
find_mul_dot{},
find_dot_slice{},
find_dot_mul{},
find_mul_add{},
find_unit_ops{},
Expand Down
167 changes: 167 additions & 0 deletions test/simplify_algebra_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4081,6 +4081,173 @@ TEST_CASE(dot_mul_b_non_const)
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(dot_slice_b)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
migraphx::module m1;
{
auto a = m1.add_parameter("a", as);
auto b = m1.add_parameter("b", bs);
auto dot = m1.add_instruction(migraphx::make_op("dot"), a, b);
auto slice = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {64}}, {"ends", {128}}}), dot);
m1.add_return({slice});
};
run_pass(m1);
migraphx::module m2;
{
auto a = m2.add_parameter("a", as);
auto b = m2.add_parameter("b", bs);
auto b_slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {64}}, {"ends", {128}}}), b);
auto dot = m2.add_instruction(migraphx::make_op("dot"), a, b_slice);
m2.add_return({dot});
};
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(dot_slice_a)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
migraphx::module m1;
{
auto a = m1.add_parameter("a", as);
auto b = m1.add_parameter("b", bs);
auto dot = m1.add_instruction(migraphx::make_op("dot"), a, b);
auto slice = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {64}}, {"ends", {128}}}), dot);
m1.add_return({slice});
};
run_pass(m1);
migraphx::module m2;
{
auto a = m2.add_parameter("a", as);
auto b = m2.add_parameter("b", bs);
auto a_slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {64}}, {"ends", {128}}}), a);
auto dot = m2.add_instruction(migraphx::make_op("dot"), a_slice, b);
m2.add_return({dot});
};
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(dot_slice_ab)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
migraphx::module m1;
{
auto a = m1.add_parameter("a", as);
auto b = m1.add_parameter("b", bs);
auto dot = m1.add_instruction(migraphx::make_op("dot"), a, b);
auto slice = m1.add_instruction(
migraphx::make_op("slice",
{{"axes", {1, 2}}, {"starts", {64, 64}}, {"ends", {128, 128}}}),
dot);
m1.add_return({slice});
};
run_pass(m1);
migraphx::module m2;
{
auto a = m2.add_parameter("a", as);
auto b = m2.add_parameter("b", bs);
auto a_slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {64}}, {"ends", {128}}}), a);
auto b_slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {64}}, {"ends", {128}}}), b);
auto dot = m2.add_instruction(migraphx::make_op("dot"), a_slice, b_slice);
m2.add_return({dot});
};
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(dot_slice_batch_dims)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
migraphx::module m1;
{
auto a = m1.add_parameter("a", as);
auto b = m1.add_parameter("b", bs);
auto dot = m1.add_instruction(migraphx::make_op("dot"), a, b);
auto slice = m1.add_instruction(
migraphx::make_op(
"slice", {{"axes", {0, 1, 2}}, {"starts", {0, 64, 64}}, {"ends", {1, 128, 128}}}),
dot);
m1.add_return({slice});
};
run_pass(m1);
migraphx::module m2;
{
auto a = m2.add_parameter("a", as);
auto b = m2.add_parameter("b", bs);
auto a_slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 64}}, {"ends", {1, 128}}}),
a);
auto b_slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0, 2}}, {"starts", {0, 64}}, {"ends", {1, 128}}}),
b);
auto dot = m2.add_instruction(migraphx::make_op("dot"), a_slice, b_slice);
m2.add_return({dot});
};
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(dot_slice_not_applicable_1)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
migraphx::module m1;
{
auto a = m1.add_parameter("a", as);
auto b = m1.add_parameter("b", bs);
auto dot = m1.add_instruction(migraphx::make_op("dot"), a, b);
auto slice1 = m1.add_instruction(
migraphx::make_op("slice",
{{"axes", {1, 2}}, {"starts", {64, 64}}, {"ends", {128, 128}}}),
dot);
auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), dot);

m1.add_return({slice1, slice2});
};
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(dot_slice_not_applicable_2)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
migraphx::module m1;
{
auto a = m1.add_parameter("a", as);
auto b = m1.add_parameter("b", bs);
auto dot = m1.add_instruction(migraphx::make_op("dot"), a, b);
auto slice = m1.add_instruction(
migraphx::make_op("slice",
{{"axes", {-2, -1}}, {"starts", {64, 64}}, {"ends", {128, 128}}}),
dot);
m1.add_return({slice});
};
run_pass(m1);
migraphx::module m2;
{
auto a = m2.add_parameter("a", as);
auto b = m2.add_parameter("b", bs);
auto a_slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {64}}, {"ends", {128}}}), a);
auto b_slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {64}}, {"ends", {128}}}), b);
auto dot = m2.add_instruction(migraphx::make_op("dot"), a_slice, b_slice);
m2.add_return({dot});
};
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(conv_concat)
{
migraphx::shape xs{migraphx::shape::float_type, {1, 8, 4, 4}};
Expand Down

0 comments on commit 262f79f

Please sign in to comment.