Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resize onnx operator: Optimization for Compute and Space performance of its linear option. #3773

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 57 additions & 29 deletions src/onnx/parse_resize.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -35,11 +35,15 @@

static std::vector<int>
calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& vvv_ind,
int i_dim,
size_t i_dim, // input lens index
lakhinderwalia marked this conversation as resolved.
Show resolved Hide resolved
size_t r_dim, // resized index
lakhinderwalia marked this conversation as resolved.
Show resolved Hide resolved
std::vector<std::vector<std::size_t>> vec_dims,
const shape& in_s)
const shape& in_s,
const shape& out_s)
{
if(i_dim == vvv_ind.size())
auto&& o_lens = out_s.lens();

if(i_dim == o_lens.size())
{
std::vector<int> vec_ind(vec_dims.size());
std::transform(vec_dims.begin(), vec_dims.end(), vec_ind.begin(), [&](auto idx) {
Expand All @@ -48,7 +52,19 @@
return vec_ind;
}

const auto& vv_lo = vvv_ind[i_dim][0];
size_t o_dim_size = o_lens[i_dim];

// when a dimension is unchanged in Resize, its indices are identical, and just copied:
if(in_s.lens()[i_dim] == o_dim_size)
{
for(std::size_t v_idx = 0; v_idx < vec_dims.size(); v_idx += o_dim_size)
for(size_t i = 0; i < o_dim_size; i++)
vec_dims[v_idx + i].push_back(i);
return calc_neighbor_points(vvv_ind, i_dim + 1, r_dim, std::move(vec_dims), in_s, out_s);
}

// when a dimension is unchanged in Resize, its indices are processed below:
const auto& vv_lo = vvv_ind[r_dim][0];
std::vector<std::vector<std::size_t>> vec_dims1;
for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size())
{
Expand All @@ -62,7 +78,7 @@
});
}

const auto& vv_hi = vvv_ind[i_dim][1];
const auto& vv_hi = vvv_ind[r_dim][1];
for(std::size_t start = 0; start < vec_dims.size(); start += vv_hi.size())
{
std::transform(vv_hi.begin(),
Expand All @@ -75,7 +91,7 @@
});
}
vec_dims.clear();
return calc_neighbor_points(vvv_ind, i_dim + 1, std::move(vec_dims1), in_s);
return calc_neighbor_points(vvv_ind, i_dim + 1, r_dim + 1, std::move(vec_dims1), in_s, out_s);
}

static std::string get_coord_trans_mode(const onnx_parser::attribute_map& attr)
Expand Down Expand Up @@ -350,7 +366,6 @@
": linear mode not supported for non-constant inputs");

shape out_s{in_s.type(), out_lens};
std::size_t out_elements = out_s.elements();

// reshape input to one-dimension
std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
Expand All @@ -359,41 +374,55 @@
auto nearest_floor = op::resize::get_nearest_op("floor");
auto nearest_ceil = op::resize::get_nearest_op("ceil");

// get the number of dimensions
std::size_t n_dim = out_lens.size();
std::size_t n_dim = out_lens.size();
lakhinderwalia marked this conversation as resolved.
Show resolved Hide resolved
std::size_t r_dim = 0; // count: lens dimensions that are resized
lakhinderwalia marked this conversation as resolved.
Show resolved Hide resolved
std::size_t out_elements = 1; // count: number of elements to fix due to resize.
for(std::size_t dim = 0; dim < n_dim; dim++)
{
if(in_lens[dim] == out_lens[dim])
continue;
r_dim++;
out_elements *= out_lens[dim];
}
std::vector<std::vector<std::size_t>> vv_ind(2, std::vector<std::size_t>(out_elements));
std::vector<std::vector<std::vector<std::size_t>>> vvv_ind(n_dim, vv_ind);
std::vector<std::vector<float>> delta(n_dim, std::vector<float>(out_elements));
std::vector<std::vector<std::vector<std::size_t>>> vvv_ind(r_dim, vv_ind);
lakhinderwalia marked this conversation as resolved.
Show resolved Hide resolved
std::vector<std::vector<float>> delta(r_dim, std::vector<float>(out_elements));

shape_for_each(out_s, [&](const auto& out_idx_v, size_t out_idx) {
for(auto ii = 0; ii < in_lens.size(); ++ii)
shape_for_each(out_s, [&](const auto& out_idx_v, std::size_t out_idx) {
std::size_t ii = 0;
for(std::size_t idx = 0; idx < in_lens.size(); ++idx)
{
auto idx_val = idx_op(in_lens[ii], out_lens[ii], out_idx_v[ii], vec_scale[ii]);
vvv_ind[ii][0][out_idx] = nearest_floor(in_lens[ii], idx_val);
vvv_ind[ii][1][out_idx] = nearest_ceil(in_lens[ii], idx_val);
if(in_lens[idx] == out_lens[idx])
continue;
auto idx_val =
idx_op(in_lens[idx], out_lens[idx], out_idx_v[idx], vec_scale[idx]);
vvv_ind[ii][0][out_idx] = nearest_floor(in_lens[idx], idx_val);
vvv_ind[ii][1][out_idx] = nearest_ceil(in_lens[idx], idx_val);
delta[ii][out_idx] = idx_val - vvv_ind[ii][0][out_idx];
++ii;
}
});

auto ind = calc_neighbor_points(
vvv_ind, 0, std::vector<std::vector<std::size_t>>(out_elements), in_s);
auto ind_lens = out_lens;
ind_lens[0] *= (std::size_t{1} << n_dim);
shape ind_s{shape::int32_type, ind_lens};
vvv_ind, 0, 0, std::vector<std::vector<std::size_t>>(out_elements), in_s, out_s);

auto dim_lens = out_lens;
dim_lens[0] *= (1u << r_dim);
lakhinderwalia marked this conversation as resolved.
Show resolved Hide resolved
shape ind_s{shape::int32_type, dim_lens};
auto ins_ind = info.add_literal(literal(ind_s, ind));
auto data = info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);

auto dim_lens = out_lens;
dim_lens[0] *= (std::size_t{1} << (n_dim - 1));
for(std::size_t i = 0; i < n_dim; ++i)
std::size_t lens_idx = out_lens.size() - 1;
for(std::size_t i = 0; i < r_dim; lens_idx--)
{
if(in_lens[lens_idx] == out_lens[lens_idx])
continue;

Check warning on line 419 in src/onnx/parse_resize.cpp

View check run for this annotation

Codecov / codecov/patch

src/onnx/parse_resize.cpp#L419

Added line #L419 was not covered by tests
dim_lens[0] /= 2; // halved for 2 slices of data
shape dim_s{shape::float_type, dim_lens};
const auto& dim_delta = delta[n_dim - i - 1];
const auto& dim_delta = delta[r_dim - i - 1];
std::vector<float> delta_data;
for(std::size_t j = 0; j < dim_lens[0] / out_lens[0]; ++j)
{
delta_data.insert(delta_data.begin(), dim_delta.begin(), dim_delta.end());
}
auto ins_delta = info.add_literal(dim_s, delta_data);

// slice the data
Expand All @@ -408,9 +437,8 @@
auto diff = info.add_instruction(make_op("sub"), hi, low);
auto ddf = info.add_instruction(make_op("mul"), diff, ins_delta);
data = info.add_instruction(make_op("add"), ddf, low);
dim_lens[0] /= 2;
i++;
}

return data;
}
}
Expand Down
14 changes: 14 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11457,6 +11457,20 @@ def resize_upsample_linear_test():
return ([node], [X], [Y], [scales_tensor])


@onnx_test()
def resize_upsample_linear_large_test():
x = helper.make_tensor_value_info('X', TensorProto.FLOAT,
[1, 1, 1024, 1024])
s = helper.make_tensor('scales', TensorProto.FLOAT, [4], [1, 1, 2, 2])
y = helper.make_tensor_value_info('Y', TensorProto.FLOAT,
[1, 1, 2048, 2048])
node = onnx.helper.make_node('Resize',
inputs=['X', '', 'scales'],
outputs=['Y'],
mode='linear')
return ([node], [x], [y], [s])


@onnx_test()
def resize_upsample_pf_test():
scales = np.array([1.0, 1.0, 2.0, 3.0], dtype=np.float32)
Expand Down
113 changes: 56 additions & 57 deletions test/onnx/include/onnx_test_utils.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -325,6 +325,29 @@ inline migraphx::program make_quantizelinear_axis_prog()
return p;
}

/* Parsed IR equivalent of create_upsample_linear_prog()
module: "main"
@0 = @literal{ ... } -> float_type, {1, 1, 4, 4}, {16, 16, 4, 1}
@1 = @literal{ ... } -> float_type, {2, 1, 4, 4}, {16, 16, 4, 1}
@2 = @literal{ ... } -> int32_type, {4, 1, 4, 4}, {16, 16, 4, 1}
X = @param:X -> float_type, {1, 1, 2, 2}, {4, 4, 2, 1}
@4 = @literal{1, 1, 2, 2} -> float_type, {4}, {1}
@5 = undefined -> float_type, {}, {}
@6 = reshape[dims={4}](X) -> float_type, {4}, {1}
@7 = gather[axis=0](@6,@2) -> float_type, {4, 1, 4, 4}, {16, 16, 4, 1}
@8 = slice[axes={0},starts={0},ends={2}](@7) -> float_type, {2, 1, 4, 4}, {16, 16, 4, 1}
@9 = slice[axes={0},starts={2},ends={4}](@7) -> float_type, {2, 1, 4, 4}, {16, 16, 4, 1}
@10 = sub(@9,@8) -> float_type, {2, 1, 4, 4}, {16, 16, 4, 1}
@11 = mul(@10,@1) -> float_type, {2, 1, 4, 4}, {16, 16, 4, 1}
@12 = add(@11,@8) -> float_type, {2, 1, 4, 4}, {16, 16, 4, 1}
@13 = slice[axes={0},starts={0},ends={1}](@12) -> float_type, {1, 1, 4, 4}, {16, 16, 4, 1}
@14 = slice[axes={0},starts={1},ends={2}](@12) -> float_type, {1, 1, 4, 4}, {16, 16, 4, 1}
@15 = sub(@14,@13) -> float_type, {1, 1, 4, 4}, {16, 16, 4, 1}
@16 = mul(@15,@0) -> float_type, {1, 1, 4, 4}, {16, 16, 4, 1}
@17 = add(@16,@13) -> float_type, {1, 1, 4, 4}, {16, 16, 4, 1}
@18 = @return(@17)
*/

inline auto create_upsample_linear_prog()
{
migraphx::program p;
Expand All @@ -335,75 +358,51 @@ inline auto create_upsample_linear_prog()

migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
auto x = mm->add_parameter("X", sx);
migraphx::shape s_ind{migraphx::shape::int32_type, {16, 1, 4, 4}};
std::vector<int> d_ind = {
0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2,
2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2,
3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1,
2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0,
1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3,
3, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3,
3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3,
2, 3, 3, 3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3};
auto l_ind = mm->add_literal(migraphx::literal(s_ind, d_ind));
migraphx::shape s_ind{migraphx::shape::int32_type, {4, 1, 4, 4}};

migraphx::shape s8{migraphx::shape::float_type, {8, 1, 4, 4}};
std::vector<float> d8 = {
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0};
auto l8 = mm->add_literal(migraphx::literal(s8, d8));

migraphx::shape s4{migraphx::shape::float_type, {4, 1, 4, 4}};
std::vector<float> d4 = {
0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3,
2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0,
0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3,
2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0,
0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3,
2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0,
0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3,
2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0};
auto l4 = mm->add_literal(migraphx::literal(s4, d4));
std::vector<int> d_ind = {0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2,
2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1,
2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3};

auto l_ind = mm->add_literal(migraphx::literal(s_ind, d_ind));

migraphx::shape s2{migraphx::shape::float_type, {2, 1, 4, 4}};
std::vector<float> d2(32, 0);

std::vector<float> d2 = {-0.25, 0.25, 0.75, 0.25, -0.25, 0.25, 0.75, 0.25,
-0.25, 0.25, 0.75, 0.25, -0.25, 0.25, 0.75, 0.25,
-0.25, 0.25, 0.75, 0.25, -0.25, 0.25, 0.75, 0.25,
-0.25, 0.25, 0.75, 0.25, -0.25, 0.25, 0.75, 0.25};

auto l2 = mm->add_literal(migraphx::literal(s2, d2));

migraphx::shape s1{migraphx::shape::float_type, {1, 1, 4, 4}};
std::vector<float> d1(16, 0.0f);

std::vector<float> d1 = {-0.25,
-0.25,
-0.25,
-0.25,
0.25,
0.25,
0.25,
0.25,
0.75,
0.75,
0.75,
0.75,
0.25,
0.25,
0.25,
0.25};

auto l1 = mm->add_literal(migraphx::literal(s1, d1));

mm->add_instruction(migraphx::make_op("undefined"));
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), x);
auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind);
auto slc80 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data);
auto slc81 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {8}}, {"ends", {16}}}), data);
auto diff8 = mm->add_instruction(migraphx::make_op("sub"), slc81, slc80);
auto mul8 = mm->add_instruction(migraphx::make_op("mul"), diff8, l8);
auto add8 = mm->add_instruction(migraphx::make_op("add"), mul8, slc80);
auto slc40 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4}}}), add8);
auto slc41 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {8}}}), add8);
auto diff4 = mm->add_instruction(migraphx::make_op("sub"), slc41, slc40);
auto mul4 = mm->add_instruction(migraphx::make_op("mul"), diff4, l4);
auto add4 = mm->add_instruction(migraphx::make_op("add"), mul4, slc40);
auto slc20 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), add4);
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), data);
auto slc21 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), add4);
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), data);
auto diff2 = mm->add_instruction(migraphx::make_op("sub"), slc21, slc20);
auto mul2 = mm->add_instruction(migraphx::make_op("mul"), diff2, l2);
auto add2 = mm->add_instruction(migraphx::make_op("add"), mul2, slc20);
Expand Down
Loading
Loading