Skip to content

Commit

Permalink
Catch invalid broadcasts in pointwise reduce fusion (#3659)
Browse files Browse the repository at this point in the history
  • Loading branch information
shivadbhavsar authored Nov 29, 2024
1 parent 8503b31 commit 241e24e
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 1 deletion.
43 changes: 42 additions & 1 deletion src/fuse_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,53 @@ static auto any_input(Ms... ms)
return match::any_of[match::inputs()](match::any(ms...).bind("input"));
}

bool is_valid_broadcast(const instruction_ref b, const std::vector<size_t>& reduce_axes)
{
std::vector<size_t> broadcast_axes;
auto bstrides = b->get_shape().strides();

for(size_t i = 0; i < bstrides.size(); ++i)
{
if(bstrides.at(i) == 0)
broadcast_axes.push_back(i);
}

return broadcast_axes == reduce_axes;
}

template <class M>
static auto match_broadcast_axes(M m)
{
return match::make_basic_fun_matcher(
[=](match::matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
optional<instruction_ref> result = m.match(ctx, ins);
if(contains(ctx.instructions, "broadcast"))
{
instruction_ref reduce;
if(ins->get_operator().name() == "fused_reduce")
{
reduce = ins;
}
else
{
assert(contains(ctx.instructions, "reduce"));
reduce = ctx.instructions["reduce"];
}
auto axes = reduce->get_operator().to_value().at("axes").to_vector<size_t>();
auto broadcast = ctx.instructions["broadcast"];
if(not is_valid_broadcast(broadcast, axes))
return nullopt;
}
return result;
});
}

static auto match_broadcastable_input(const std::string& op, const std::string& name)
{
auto match_op = match::name(op)(used_once_except_broadcast()).bind(name);
auto match_op_input = any_input(match_op, match::used_once());
auto broadcast_match_op_input = any_input(match_broadcast(match_op), match::used_once());
return match::any_of(match_op_input, broadcast_match_op_input);
return match::any_of(match_op_input, match_broadcast_axes(broadcast_match_op_input));
}

static void finalize_reduce_module(module_ref m)
Expand Down
126 changes: 126 additions & 0 deletions test/fuse_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,44 @@ TEST_CASE(pointwise_reduce)
EXPECT(p1 == p2);
}

TEST_CASE(pointwise_reduce_unfusable_broadcast)
{
migraphx::shape s{migraphx::shape::float_type, {2, 1, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto addb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), add);
auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), addb);
mm->add_return({rsum});
}
run_pass(p1);

migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add = add_pointwise(p2, "main:pointwise0", {x, y}, single_pointwise("add"));
auto addb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), add);
auto rsum =
add_reduce(p2,
"main:reduce_sum0",
{addb},
{2},
[&](auto* rm, const auto& inputs, const auto& axes) {
return rm->add_instruction(
migraphx::make_op("reduce_sum", {{"axes", axes}}), inputs[0]);
});
mm->add_return({rsum});
}
EXPECT(p1 == p2);
}

TEST_CASE(scalar_multibroadcast)
{
// Matches the find_pointwise_reduce matcher, but input x has a (scalar) shape
Expand Down Expand Up @@ -283,6 +321,43 @@ TEST_CASE(reduce_pointwise)
EXPECT(p1 == p2);
}

TEST_CASE(reduce_pointwise_unfusable_broadcast)
{
migraphx::shape s{migraphx::shape::float_type, {2, 1, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x);
auto rsumb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), rsum);
auto yb =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), y);
auto add = add_pointwise(p1, "main:pointwise0", {rsumb, yb}, single_pointwise("add"));
mm->add_return({add});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto rsum = add_reduce(
p2, "main:reduce_sum0", {x}, {2}, [&](auto* rm, const auto& inputs, const auto& axes) {
return rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}),
inputs[0]);
});
auto rsumb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), rsum);
auto yb =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), y);
auto add = add_pointwise(p2, "main:pointwise0", {rsumb, yb}, single_pointwise("add"));
mm->add_return({add});
}
EXPECT(p1 == p2);
}

TEST_CASE(reduce_reduce)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
Expand Down Expand Up @@ -325,6 +400,57 @@ TEST_CASE(reduce_reduce)
EXPECT(p1 == p2);
}

TEST_CASE(reduce_reduce_unfusable_broadcast)
{
migraphx::shape s{migraphx::shape::float_type, {2, 1, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x);
auto rsumb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), rsum);
auto xb =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), x);
auto rsumdiff = add_pointwise(p1, "main:pointwise0", {rsumb, xb}, single_pointwise("sub"));
auto rsum2 =
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), rsumdiff);
auto sqrt = add_pointwise(p1, "main:pointwise1", {rsum2}, single_pointwise("sqrt"));
mm->add_return({sqrt});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum = add_reduce(
p2, "main:reduce_sum0", {x}, {2}, [&](auto* rm, const auto& inputs, const auto& axes) {
return rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}),
inputs[0]);
});

auto rsumb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), rsum);
auto xb =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), x);

auto sqrt = add_reduce(
p2,
"main:pointwise0:main:reduce_sum1:main:pointwise1",
{rsumb, xb},
{2},
[&](auto* rm, const auto& inputs, const auto& axes) {
auto rsumdiff = add_pointwise(
p2, rm, "main:pointwise0", {inputs[0], inputs[1]}, single_pointwise("sub"));
auto rsum2 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}),
rsumdiff);
return add_pointwise(p2, rm, "main:pointwise1", {rsum2}, single_pointwise("sqrt"));
});
mm->add_return({sqrt});
}
EXPECT(p1 == p2);
}

TEST_CASE(parallel_reduce_reduce)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
Expand Down
40 changes: 40 additions & 0 deletions test/include/layernorm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,43 @@ inline migraphx::instruction_ref add_layernorm(migraphx::module& m,
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), bias);
return m.add_instruction(migraphx::make_op("add"), mul, bias_mbcast);
}

inline migraphx::instruction_ref add_pointwise_layernorm(migraphx::module& m,
migraphx::instruction_ref x,
const std::vector<size_t>& dims,
float eps = 1e-12f)
{
auto mgx_type = x->get_shape().type();
auto scale = m.add_parameter("scale", migraphx::shape{mgx_type, {1, 1, dims.back()}});
auto bias = m.add_parameter("bias", migraphx::shape{mgx_type, {1, 1, dims.back()}});

auto epsilon = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {eps}});
auto one = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {1}});

auto mean = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), x);
auto mean_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto x_minus_mean = m.add_instruction(migraphx::make_op("sub"), x, mean_mbcast);
auto sqdiff = m.add_instruction(migraphx::make_op("sqdiff"), x, mean_mbcast);
auto var = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), sqdiff);

auto epsilon_mbcast = m.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", var->get_shape().lens()}}), epsilon);
auto var_stable = m.add_instruction(migraphx::make_op("add"), var, epsilon_mbcast);
auto inv_stddev_x = m.add_instruction(migraphx::make_op("rsqrt"), var_stable);
auto inv_stddev_x_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), inv_stddev_x);
auto norm = m.add_instruction(migraphx::make_op("mul"), x_minus_mean, inv_stddev_x_mbcast);

auto one_mbcast = m.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", scale->get_shape().lens()}}), one);
auto add_scale = m.add_instruction(migraphx::make_op("add"), scale, one_mbcast);
auto scale_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), add_scale);
auto scale_norm = m.add_instruction(migraphx::make_op("mul"), norm, scale_mbcast);

auto bias_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), bias);

return m.add_instruction(migraphx::make_op("add"), scale_norm, bias_mbcast);
}
13 changes: 13 additions & 0 deletions test/verify/test_layernorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,16 @@ struct test_add_layernorm_add_gemm_nonstd : verify_program<test_add_layernorm_ad
}
std::string section() const { return "gemm"; }
};

struct test_pw_layernorm : verify_program<test_pw_layernorm>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 9, 6};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
add_pointwise_layernorm(*mm, x, dims);
return p;
}
};

0 comments on commit 241e24e

Please sign in to comment.