Skip to content

Commit

Permalink
Merge branch 'develop' into inner_broadcast_layout
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieL7 authored Jan 21, 2025
2 parents e7c7c4b + 50c6848 commit 47194f6
Show file tree
Hide file tree
Showing 11 changed files with 190 additions and 96 deletions.
35 changes: 21 additions & 14 deletions src/onnx/parse_convolution.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -141,17 +141,14 @@ struct parse_convolution : op_parser<parse_convolution>
return all_zeros;
}

static auto
static migraphx::operation
qparam_broadcast_op(instruction_ref qparam, std::vector<std::size_t> lens, std::size_t axis)
{
if(qparam->get_shape().scalar())
if(qparam->get_shape().elements() == 1)
{
return migraphx::make_op("multibroadcast", {{"out_lens", lens}});
}
else
{
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});
}
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});
}

static instruction_ref handle_quant_bias(const operation& op,
Expand All @@ -162,27 +159,37 @@ struct parse_convolution : op_parser<parse_convolution>
const instruction_ref& w_zp,
onnx_parser::node_info& info)
{
// to handle the bias, apply the following transformation:
// conv(x-x_zp,w-w_zp) = conv(x,w) - conv(x_zp,w) - conv(x,w_zp) + conv(x_zp,w_zp)
instruction_ref ret = input;

// multibroadcast (or broadcast) zero points according to spec
// x_zp should be a scalar or literal with one element
// w_zp can be either a single element or a 1d tensor with size out_channels
migraphx::operation x_zp_bc =
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}});
migraphx::operation w_zp_bc = qparam_broadcast_op(w_zp, weights->get_shape().lens(), 0);

if(not is_symmetric_zero_point(x_zp))
{
auto out_zp_1 = info.add_common_op(op.name(), x_zp, weights);
auto x_zp_mb = info.add_instruction(x_zp_bc, x_zp);
auto out_zp_1 = info.add_instruction(op, x_zp_mb, weights);
ret = info.add_common_op("sub", ret, out_zp_1);
}

if(not is_symmetric_zero_point(w_zp))
{
auto out_zp_2 = info.add_common_op(op.name(), x, w_zp);
auto w_zp_mb = info.add_instruction(w_zp_bc, w_zp);
auto out_zp_2 = info.add_instruction(op, x, w_zp_mb);
ret = info.add_common_op("sub", ret, out_zp_2);
}

if(not(is_symmetric_zero_point(x_zp)) and not(is_symmetric_zero_point(w_zp)))
{
auto x_zp_bc =
info.add_instruction(qparam_broadcast_op(x_zp, x->get_shape().lens(), 0), x_zp);
auto w_zp_bc = info.add_instruction(
qparam_broadcast_op(w_zp, weights->get_shape().lens(), 0), w_zp);
auto x_zp_mb = info.add_instruction(x_zp_bc, x_zp);
auto w_zp_mb = info.add_instruction(w_zp_bc, w_zp);

auto out_zp_3 = info.add_instruction(op, x_zp_bc, w_zp_bc);
auto out_zp_3 = info.add_instruction(op, x_zp_mb, w_zp_mb);

ret = info.add_common_op("add", ret, out_zp_3);
}
Expand Down
4 changes: 3 additions & 1 deletion src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -761,6 +761,8 @@ struct find_inner_broadcast
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
if(ins->get_operator().name() == "layout")
return;
const auto& broadcasts = ins->inputs();
if(broadcasts.empty())
return;
Expand Down
2 changes: 1 addition & 1 deletion test/onnx/.onnxrt-commit
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1ce59577d54be75afdb6123cd465f18456d78611
a05203b91c96ca5ca0112f9403267c04da55ec78
6 changes: 3 additions & 3 deletions test/onnx/convinteger_bias_test.onnx
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
strides@@�convinteger_bias_testZ
0





 Z
1





Z
Expand All @@ -23,7 +23,7 @@
b
3





B
34 changes: 34 additions & 0 deletions test/onnx/convinteger_dual_bias_simple_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
 !convinteger_dual_bias_simple_test:�
B
0
1
2
34" ConvInteger*
dilations@@�*
strides@@�!convinteger_dual_bias_simple_testZ
0




Z
1




Z
2


Z
3


b
4




B
20 changes: 11 additions & 9 deletions test/onnx/convinteger_dual_bias_test.onnx
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@ B
strides@@�convinteger_dual_bias_testZ
0





Z



Z
1





Z

Z
2


Expand All @@ -28,7 +30,7 @@ B
b
4





B

B
23 changes: 20 additions & 3 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1670,10 +1670,10 @@ def convinteger_no_bias_uint8_test():

@onnx_test()
def convinteger_bias_test():
x = helper.make_tensor_value_info('0', TensorProto.INT8, [1, 3, 32, 32])
y = helper.make_tensor_value_info('1', TensorProto.INT8, [1, 3, 5, 5])
x = helper.make_tensor_value_info('0', TensorProto.INT8, [2, 3, 32, 32])
y = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3, 5, 5])
z = helper.make_tensor_value_info('2', TensorProto.INT8, [1])
out = helper.make_tensor_value_info('3', TensorProto.INT32, [1, 2, 28, 28])
out = helper.make_tensor_value_info('3', TensorProto.INT32, [2, 4, 28, 28])

node = onnx.helper.make_node('ConvInteger',
inputs=['0', '1', '2'],
Expand All @@ -1686,6 +1686,23 @@ def convinteger_bias_test():

@onnx_test()
def convinteger_dual_bias_test():
x = helper.make_tensor_value_info('0', TensorProto.INT8, [2, 3, 10, 10])
y = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3, 3, 3])
z = helper.make_tensor_value_info('2', TensorProto.INT8, [1])
w = helper.make_tensor_value_info('3', TensorProto.INT8, [1])
out = helper.make_tensor_value_info('4', TensorProto.INT32, [2, 4, 8, 8])

node = onnx.helper.make_node('ConvInteger',
inputs=['0', '1', '2', '3'],
outputs=['4'],
dilations=[1, 1],
strides=[1, 1])

return ([node], [x, y, z, w], [out])


@onnx_test()
def convinteger_dual_bias_simple_test():
x = helper.make_tensor_value_info('0', TensorProto.INT8, [1, 3, 5, 5])
y = helper.make_tensor_value_info('1', TensorProto.INT8, [1, 3, 2, 2])
z = helper.make_tensor_value_info('2', TensorProto.INT8, [1])
Expand Down
14 changes: 5 additions & 9 deletions test/onnx/parse/convinteger_bias_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -28,24 +28,20 @@ TEST_CASE(convinteger_bias_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto data = mm->add_parameter("0", {migraphx::shape::int8_type, {1, 3, 32, 32}});
auto weights = mm->add_parameter("1", {migraphx::shape::int8_type, {1, 3, 5, 5}});
auto data = mm->add_parameter("0", {migraphx::shape::int8_type, {2, 3, 32, 32}});
auto weights = mm->add_parameter("1", {migraphx::shape::int8_type, {4, 3, 5, 5}});
auto data_bias = mm->add_parameter("2", {migraphx::shape::int8_type, {1}, {1}});

mm->add_literal(migraphx::literal{migraphx::shape{data->get_shape().type(), {1}, {0}}, {0}});
auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), data, weights);

auto bcast_data_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", weights->get_shape().lens()}}),
data_bias);
migraphx::make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}), data_bias);

auto quant2 =
mm->add_instruction(migraphx::make_op("quant_convolution"), bcast_data_bias, weights);

auto bcast_quant2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), quant2);

mm->add_instruction(migraphx::make_op("sub"), quant, bcast_quant2);
mm->add_instruction(migraphx::make_op("sub"), quant, quant2);

auto prog = optimize_onnx("convinteger_bias_test.onnx");
EXPECT(p == prog);
Expand Down
32 changes: 14 additions & 18 deletions test/onnx/parse/convinteger_dual_bias_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -28,41 +28,37 @@ TEST_CASE(convinteger_dual_bias_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto data = mm->add_parameter("0", {migraphx::shape::int8_type, {1, 3, 5, 5}});
auto weight = mm->add_parameter("1", {migraphx::shape::int8_type, {1, 3, 2, 2}});
auto data = mm->add_parameter("0", {migraphx::shape::int8_type, {2, 3, 10, 10}});
auto weight = mm->add_parameter("1", {migraphx::shape::int8_type, {4, 3, 3, 3}});
auto data_bias = mm->add_parameter("2", {migraphx::shape::int8_type, {1}, {1}});
auto weight_bias = mm->add_parameter("3", {migraphx::shape::int8_type, {1}, {1}});

auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), data, weight);

auto mbcast_data_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", weight->get_shape().lens()}}), data_bias);
migraphx::make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}), data_bias);

auto quant_db_w =
auto quant_mb_w =
mm->add_instruction(migraphx::make_op("quant_convolution"), mbcast_data_bias, weight);

auto quant_mb_w = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), quant_db_w);

quant = mm->add_instruction(migraphx::make_op("sub"), quant, quant_mb_w);

auto mbcast_weight_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}), weight_bias);
migraphx::make_op("multibroadcast", {{"out_lens", weight->get_shape().lens()}}),
weight_bias);

auto quant_d_wb =
auto quant_md_wb =
mm->add_instruction(migraphx::make_op("quant_convolution"), data, mbcast_weight_bias);

auto quant_md_wb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), quant_d_wb);

quant = mm->add_instruction(migraphx::make_op("sub"), quant, quant_md_wb);

auto bcast_data_bias = mm->add_instruction(
migraphx::make_op("broadcast", {{"out_lens", data->get_shape().lens()}}), data_bias);
auto bcast_weight_bias = mm->add_instruction(
migraphx::make_op("broadcast", {{"out_lens", weight->get_shape().lens()}}), weight_bias);
mbcast_data_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}), data_bias);
mbcast_weight_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", weight->get_shape().lens()}}),
weight_bias);
auto bias_quant = mm->add_instruction(
migraphx::make_op("quant_convolution"), bcast_data_bias, bcast_weight_bias);
migraphx::make_op("quant_convolution"), mbcast_data_bias, mbcast_weight_bias);

mm->add_instruction(migraphx::make_op("add"), quant, bias_quant);

Expand Down
Loading

0 comments on commit 47194f6

Please sign in to comment.