From c9bbea151de896d46ae3a64fc7cc296d5ef8311a Mon Sep 17 00:00:00 2001
From: turneram <71655887+turneram@users.noreply.github.com>
Date: Fri, 23 Aug 2024 08:42:50 -0500
Subject: [PATCH] Add ONNX parsing for SkipSimplifiedLayerNormalization (#3140)

---
 ...se_skip_simplified_layer_normalization.cpp | 139 ++++++++++
 test/onnx/gen_onnx.py                         |  82 ++++++
 ...layer_normalization_invalid_input_test.cpp |  32 +++
 ...ayer_normalization_invalid_n_args_test.cpp |  32 +++
 ...ip_simplified_layer_normalization_test.cpp |  35 +++
 ...plified_layer_normalization_bias_test.onnx |  51 ++++
 ...ayer_normalization_invalid_input_test.onnx |  31 +++
 ...yer_normalization_invalid_n_args_test.onnx |  23 ++
 ...p_simplified_layer_normalization_test.onnx |  45 ++++
 .../skip_simplified_layer_normalization.cpp   | 245 ++++++++++++++++++
 10 files changed, 715 insertions(+)
 create mode 100644 src/onnx/parse_skip_simplified_layer_normalization.cpp
 create mode 100644 test/onnx/parse/skip_simplified_layer_normalization_invalid_input_test.cpp
 create mode 100644 test/onnx/parse/skip_simplified_layer_normalization_invalid_n_args_test.cpp
 create mode 100644 test/onnx/parse/skip_simplified_layer_normalization_test.cpp
 create mode 100644 test/onnx/skip_simplified_layer_normalization_bias_test.onnx
 create mode 100644 test/onnx/skip_simplified_layer_normalization_invalid_input_test.onnx
 create mode 100644 test/onnx/skip_simplified_layer_normalization_invalid_n_args_test.onnx
 create mode 100644 test/onnx/skip_simplified_layer_normalization_test.onnx
 create mode 100644 test/onnx/verify/skip_simplified_layer_normalization.cpp

diff --git a/src/onnx/parse_skip_simplified_layer_normalization.cpp b/src/onnx/parse_skip_simplified_layer_normalization.cpp
new file mode 100644
index 00000000000..0e61cce3be3
--- /dev/null
+++ b/src/onnx/parse_skip_simplified_layer_normalization.cpp
@@ -0,0 +1,139 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#include
+#include
+#include
+#include
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace onnx {
+
+// com.microsoft.SkipSimplifiedLayerNormalization
+// Skip and Root Mean Square Layer Normalization
+
+// Version
+// This version of the operator has been available since version 1 of the 'com.microsoft'
+// operator set.
+
+// Type Constraints
+// T : tensor(float), tensor(float16)
+// Constrain input and output types to float or half tensors.
+// U : tensor(float)
+// Constrain mean and inv_std_var to float tensors.
+
+struct parse_skip_simplified_layer_normalization
+    : op_parser<parse_skip_simplified_layer_normalization>
+{
+    std::vector<op_desc> operators() const { return {{"SkipSimplifiedLayerNormalization"}}; }
+
+    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
+                                       const onnx_parser& parser,
+                                       const onnx_parser::node_info& info,
+                                       std::vector<instruction_ref> args) const
+    {
+        // Attributes
+        // epsilon : float
+        // The epsilon value to use to avoid division by zero.
+        float epsilon = 1e-5f;
+        if(contains(info.attributes, "epsilon"))
+        {
+            epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
+        }
+
+        // Inputs (3 - 4)
+        // input : T
+        // 3D input tensor with shape (batch_size, sequence_length, hidden_size)
+        // or 2D input tensor with shape (token_count, hidden_size)
+        // skip : T
+        // 3D input tensor with shape (batch_size, sequence_length, hidden_size)
+        // or 2D input tensor with shape (token_count, hidden_size)
+        // gamma : T
+        // 1D input tensor with shape (hidden_size)
+        // bias (optional) : T
+        // 1D bias tensor with shape (hidden_size) - not used by ORT
+
+        if(args.size() < 3 or args.size() > 4)
+        {
+            MIGRAPHX_THROW("PARSE_SKIPSIMPLIFIEDLAYERNORMALIZATION: invalid input count");
+        }
+
+        auto x     = args.at(0);
+        auto skip  = args.at(1);
+        auto gamma = args.at(2);
+        instruction_ref bias;
+        if(args.size() == 4)
+        {
+            bias = args.at(3);
+        }
+
+        auto x_shape       = x->get_shape();
+        auto x_dtype       = x_shape.type();
+        int64_t x_rank     = x_shape.ndim();
+        int64_t skip_rank  = skip->get_shape().ndim();
+        int64_t gamma_rank = gamma->get_shape().ndim();
+        // axis = hidden_size dim
+        int64_t axis = x_rank - 1;
+
+        if(x_rank < 2 or x_rank > 3 or x_rank != skip_rank or gamma_rank != 1)
+        {
+            MIGRAPHX_THROW("PARSE_SKIPSIMPLIFIEDLAYERNORMALIZATION: invalid input shape");
+        }
+
+        x           = info.add_common_op("add", x, skip);
+        auto x_sq   = info.add_common_op("mul", x, x);
+        auto rms    = info.add_instruction(make_op("reduce_mean", {{"axes", {axis}}}), x_sq);
+        auto mean   = rms;
+        epsilon =
+            (x_dtype == migraphx::shape::half_type and std::abs(epsilon) < 1e-7) ? 1e-7 : epsilon;
+        auto eps    = info.add_literal(migraphx::literal{migraphx::shape{x_dtype}, {epsilon}});
+        rms         = info.add_common_op("add", rms, eps);
+        auto rrms   = info.add_instruction(make_op("rsqrt"), rms);
+        auto result = info.add_common_op("mul", x, rrms);
+        result      = info.add_common_op("mul", result, gamma);
+        if(args.size() == 4)
+        {
+            result = info.add_common_op("add", result, bias);
+            x      = info.add_common_op("add", x, bias);
+        }
+
+        // Outputs (1 - 4)
+        // output : T
+        // 3D output tensor with shape (batch_size, sequence_length, hidden_size)
+        // or 2D output tensor with shape (token_count, hidden_size)
+        // mean (optional) : U
+        // Saved mean used during training to speed up gradient computation
+        // inv_std_var (optional) : U
+        // Saved inverse standard variance used during training to speed up gradient computation.
+        // input_skip_bias_sum (optional) : T
+        // Sum of the input and skip inputs (and bias if it exists) with shape
+        // (batch_size, sequence_length, hidden_size) or (token_count, hidden_size).
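+
+        // The four values returned below correspond to the outputs documented above:
+        // result -> output, mean -> mean (here the mean of x*x over the hidden axis,
+        // captured before epsilon is added), rrms -> inv_std_var (1 / sqrt(mean(x*x) + epsilon)),
+        // and x -> input_skip_bias_sum (input + skip, plus bias when it is provided).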
+
+        return {result, mean, rrms, x};
+    }
+};
+
+} // namespace onnx
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx

diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py
index 3efc787c559..d0224091d23 100644
--- a/test/onnx/gen_onnx.py
+++ b/test/onnx/gen_onnx.py
@@ -10706,6 +10706,88 @@ def size_verify_test():
     return ([node], [x], [y])
 
 
+@onnx_test()
+def skip_simplified_layer_normalization_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [2, 2, 4])
+    skip = helper.make_tensor_value_info('skip', TensorProto.FLOAT16,
+                                         [2, 2, 4])
+    gamma = helper.make_tensor_value_info('gamma', TensorProto.FLOAT16, [4])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [2, 2, 4])
+    mean = helper.make_tensor_value_info('mean', TensorProto.FLOAT, [2, 2, 1])
+    inv_std_var = helper.make_tensor_value_info('inv_std_var',
+                                                TensorProto.FLOAT, [2, 2, 1])
+    input_skip_bias_sum = helper.make_tensor_value_info(
+        'input_skip_bias_sum', TensorProto.FLOAT16, [2, 2, 4])
+
+    node = onnx.helper.make_node(
+        'SkipSimplifiedLayerNormalization',
+        inputs=['x', 'skip', 'gamma'],
+        outputs=['y', 'mean', 'inv_std_var', 'input_skip_bias_sum'],
+        epsilon=1e-5,
+        domain="com.microsoft")
+
+    return ([node], [x, skip, gamma],
+            [y, mean, inv_std_var, input_skip_bias_sum])
+
+
+@onnx_test()
+def skip_simplified_layer_normalization_bias_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [2, 2, 4])
+    skip = helper.make_tensor_value_info('skip', TensorProto.FLOAT16,
+                                         [2, 2, 4])
+    gamma = helper.make_tensor_value_info('gamma', TensorProto.FLOAT16, [4])
+    bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT16, [4])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [2, 2, 4])
+    mean = helper.make_tensor_value_info('mean', TensorProto.FLOAT, [2, 2, 1])
+    inv_std_var = helper.make_tensor_value_info('inv_std_var',
+                                                TensorProto.FLOAT, [2, 2, 1])
+    input_skip_bias_sum = helper.make_tensor_value_info(
+        'input_skip_bias_sum', TensorProto.FLOAT16, [2, 2, 4])
+
+    node = onnx.helper.make_node(
+        'SkipSimplifiedLayerNormalization',
+        inputs=['x', 'skip', 'gamma', 'bias'],
+        outputs=['y', 'mean', 'inv_std_var', 'input_skip_bias_sum'],
+        epsilon=1e-5,
+        domain="com.microsoft")
+
+    return ([node], [x, skip, gamma, bias],
+            [y, mean, inv_std_var, input_skip_bias_sum])
+
+
+@onnx_test()
+def skip_simplified_layer_normalization_invalid_n_args_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [2, 2, 4])
+    skip = helper.make_tensor_value_info('skip', TensorProto.FLOAT16,
+                                         [2, 2, 4])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [2, 2, 4])
+
+    node = onnx.helper.make_node('SkipSimplifiedLayerNormalization',
+                                 inputs=['x', 'skip'],
+                                 outputs=['y'],
+                                 epsilon=1e-5,
+                                 domain="com.microsoft")
+
+    return ([node], [x, skip], [y])
+
+
+@onnx_test()
+def skip_simplified_layer_normalization_invalid_input_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [2, 2, 2, 4])
+    skip = helper.make_tensor_value_info('skip', TensorProto.FLOAT16,
+                                         [2, 2, 4])
+    gamma = helper.make_tensor_value_info('gamma', TensorProto.FLOAT16,
+                                          [2, 4])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [2, 2, 2, 4])
+
+    node = onnx.helper.make_node('SkipSimplifiedLayerNormalization',
+                                 inputs=['x', 'skip', 'gamma'],
+                                 outputs=['y'],
+                                 epsilon=1e-5,
+                                 domain="com.microsoft")
+
+    return ([node], [x, skip, gamma], [y])
+
+
 @onnx_test()
 def slice_test():
     x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3, 2])

diff --git a/test/onnx/parse/skip_simplified_layer_normalization_invalid_input_test.cpp b/test/onnx/parse/skip_simplified_layer_normalization_invalid_input_test.cpp
new file mode 100644
index 00000000000..5c0495f8c18
--- /dev/null
+++ b/test/onnx/parse/skip_simplified_layer_normalization_invalid_input_test.cpp
@@ -0,0 +1,32 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include
+
+TEST_CASE(skip_simplified_layer_normalization_invalid_input_test)
+{
+    EXPECT(test::throws([&] {
+        migraphx::parse_onnx("skip_simplified_layer_normalization_invalid_input_test.onnx");
+    }));
+}

diff --git a/test/onnx/parse/skip_simplified_layer_normalization_invalid_n_args_test.cpp b/test/onnx/parse/skip_simplified_layer_normalization_invalid_n_args_test.cpp
new file mode 100644
index 00000000000..4faa3b16833
--- /dev/null
+++ b/test/onnx/parse/skip_simplified_layer_normalization_invalid_n_args_test.cpp
@@ -0,0 +1,32 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include
+
+TEST_CASE(skip_simplified_layer_normalization_invalid_n_args_test)
+{
+    EXPECT(test::throws([&] {
+        migraphx::parse_onnx("skip_simplified_layer_normalization_invalid_n_args_test.onnx");
+    }));
+}

diff --git a/test/onnx/parse/skip_simplified_layer_normalization_test.cpp b/test/onnx/parse/skip_simplified_layer_normalization_test.cpp
new file mode 100644
index 00000000000..1e0b33f8302
--- /dev/null
+++ b/test/onnx/parse/skip_simplified_layer_normalization_test.cpp
@@ -0,0 +1,35 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include
+#include
+
+TEST_CASE(skip_simplified_layer_normalization_test)
+{
+    migraphx::program p = make_simplified_layer_norm(
+        {2, 2, 4}, {2, 2, 4}, {4}, -1, 1e-5f, migraphx::shape::half_type);
+
+    auto prog = optimize_onnx("skip_simplified_layer_normalization_test.onnx");
+    EXPECT(p == prog);
+}

diff --git a/test/onnx/skip_simplified_layer_normalization_bias_test.onnx b/test/onnx/skip_simplified_layer_normalization_bias_test.onnx
new file mode 100644
index 00000000000..83b5a9a5db3
--- /dev/null
+++ b/test/onnx/skip_simplified_layer_normalization_bias_test.onnx
[binary ONNX protobuf, 51 lines: SkipSimplifiedLayerNormalization node (domain com.microsoft, epsilon attribute) with inputs x, skip, gamma, bias and outputs y, mean, inv_std_var, input_skip_bias_sum; raw bytes not reproducible as text]

diff --git a/test/onnx/skip_simplified_layer_normalization_invalid_input_test.onnx b/test/onnx/skip_simplified_layer_normalization_invalid_input_test.onnx
new file mode 100644
index 00000000000..8e1dd5fdccd
--- /dev/null
+++ b/test/onnx/skip_simplified_layer_normalization_invalid_input_test.onnx
[binary ONNX protobuf, 31 lines: SkipSimplifiedLayerNormalization node with inputs x, skip, gamma and output y; raw bytes not reproducible as text]

diff --git a/test/onnx/skip_simplified_layer_normalization_invalid_n_args_test.onnx b/test/onnx/skip_simplified_layer_normalization_invalid_n_args_test.onnx
new file mode 100644
index 00000000000..9ee41b486c0
--- /dev/null
+++ b/test/onnx/skip_simplified_layer_normalization_invalid_n_args_test.onnx
[binary ONNX protobuf, 23 lines: SkipSimplifiedLayerNormalization node with inputs x, skip and output y; raw bytes not reproducible as text]

diff --git a/test/onnx/skip_simplified_layer_normalization_test.onnx b/test/onnx/skip_simplified_layer_normalization_test.onnx
new file mode 100644
index 00000000000..12501944a0f
--- /dev/null
+++ b/test/onnx/skip_simplified_layer_normalization_test.onnx
[binary ONNX protobuf, 45 lines: SkipSimplifiedLayerNormalization node with inputs x, skip, gamma and outputs y, mean, inv_std_var, input_skip_bias_sum; raw bytes not reproducible as text]

diff --git a/test/onnx/verify/skip_simplified_layer_normalization.cpp b/test/onnx/verify/skip_simplified_layer_normalization.cpp
new file mode 100644
index 00000000000..fd19abf7350
--- /dev/null
+++ b/test/onnx/verify/skip_simplified_layer_normalization.cpp
@@ -0,0 +1,245 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include
+#include
+#include
+#include
+
+TEST_CASE(skip_simplified_layer_normalization_test)
+{
+    using migraphx::half;
+    std::vector<half> x{half{0.8},  half{-0.5}, half{0.0},  half{1.0},
+                        half{0.5},  half{0.2},  half{0.3},  half{-0.6},
+                        half{10.0}, half{-1.0}, half{0.0},  half{1.0},
+                        half{1.2},  half{3.2},  half{-4.1}, half{5.3}};
+    std::vector<half> skip{half{1.2},  half{-1.0},  half{2.0},  half{1.0},
+                           half{1.5},  half{2.2},   half{-3.3}, half{2.6},
+                           half{1.0},  half{-10.0}, half{-1.0}, half{2.0},
+                           half{-1.2}, half{1.6},   half{-1.1}, half{1.3}};
+    std::vector<half> scale{half{0.1}, half{0.2}, half{4.0}, half{-2.2}};
+
+    auto p = read_onnx("skip_simplified_layer_normalization_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::shape s_x{migraphx::shape::half_type, {2, 2, 4}};
+    migraphx::shape s_s{migraphx::shape::half_type, {4}};
+
+    migraphx::parameter_map pp;
+    pp["x"]     = migraphx::argument(s_x, x.data());
+    pp["skip"]  = migraphx::argument(s_x, skip.data());
+    pp["gamma"] = migraphx::argument(s_s, scale.data());
+
+    auto results             = p.eval(pp);
+    auto output              = results.at(0);
+    auto mean                = results.at(1);
+    auto inv_std_var         = results.at(2);
+    auto input_skip_bias_sum = results.at(3);
+
+    std::vector<half> result_vector;
+    std::vector<half> mean_vector;
+    std::vector<half> inv_std_var_vector;
+    std::vector<half> input_skip_bias_sum_vector;
+
+    output.visit([&](auto vals) { result_vector.assign(vals.begin(), vals.end()); });
+    mean.visit([&](auto vals) { mean_vector.assign(vals.begin(), vals.end()); });
+    inv_std_var.visit([&](auto vals) { inv_std_var_vector.assign(vals.begin(), vals.end()); });
+    input_skip_bias_sum.visit(
+        [&](auto vals) { input_skip_bias_sum_vector.assign(vals.begin(), vals.end()); });
+
+    std::vector<half> gold = {half{0.10596}, half{-0.1589}, half{4.24},   half{-2.33},
+                              half{0.0838},  half{0.2012},  half{-5.03},  half{-1.844},
+                              half{0.1385},  half{-0.277},  half{-0.504}, half{-0.831},
+                              half{0.0},     half{0.1984},  half{-4.3},   half{-3.0}};
+    std::vector<half> gold_mean{half{3.5625}, half{5.6875}, half{63.0}, half{23.421875}};
+    std::vector<half> gold_inv_std_var{half{0.5298}, half{0.4194}, half{0.1260}, half{0.2067}};
+    std::vector<half> gold_input_skip_bias_sum = {
+        half{2.0},  half{-1.5},      half{2.0},       half{2.0},
+        half{2.0},  half{2.3984375}, half{-3.0},      half{2.0},
+        half{11.0}, half{-11.0},     half{-1.0},      half{3.0},
+        half{0.0},  half{4.796875},  half{-5.203125}, half{6.6015625}};
+
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+    EXPECT(migraphx::verify::verify_rms_range(mean_vector, gold_mean));
+    EXPECT(migraphx::verify::verify_rms_range(inv_std_var_vector, gold_inv_std_var));
+    EXPECT(
+        migraphx::verify::verify_rms_range(input_skip_bias_sum_vector, gold_input_skip_bias_sum));
+}
+
+TEST_CASE(skip_simplified_layer_normalization_bias_test)
+{
+    using migraphx::half;
+    std::vector<half> x{half{0.8},  half{-0.5}, half{0.0},  half{1.0},
+                        half{0.5},  half{0.2},  half{0.3},  half{-0.6},
+                        half{10.0}, half{-1.0}, half{0.0},  half{1.0},
+                        half{1.2},  half{3.2},  half{-4.1}, half{5.3}};
+    std::vector<half> skip{half{1.2},  half{-1.0},  half{2.0},  half{1.0},
+                           half{1.5},  half{2.2},   half{-3.3}, half{2.6},
+                           half{1.0},  half{-10.0}, half{-1.0}, half{2.0},
+                           half{-1.2}, half{1.6},   half{-1.1}, half{1.3}};
+    std::vector<half> scale{half{0.1}, half{0.2}, half{4.0}, half{-2.2}};
+    std::vector<half> bias{half{1.0}, half{1.0}, half{1.0}, half{1.0}};
+
+    auto p = read_onnx("skip_simplified_layer_normalization_bias_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::shape s_x{migraphx::shape::half_type, {2, 2, 4}};
+    migraphx::shape s_s{migraphx::shape::half_type, {4}};
+
+    migraphx::parameter_map pp;
+    pp["x"]     = migraphx::argument(s_x, x.data());
+    pp["skip"]  = migraphx::argument(s_x, skip.data());
+    pp["gamma"] = migraphx::argument(s_s, scale.data());
+    pp["bias"]  = migraphx::argument(s_s, bias.data());
+
+    auto results             = p.eval(pp);
+    auto output              = results.at(0);
+    auto mean                = results.at(1);
+    auto inv_std_var         = results.at(2);
+    auto input_skip_bias_sum = results.at(3);
+
+    std::vector<half> result_vector;
+    std::vector<half> mean_vector;
+    std::vector<half> inv_std_var_vector;
+    std::vector<half> input_skip_bias_sum_vector;
+
+    output.visit([&](auto vals) { result_vector.assign(vals.begin(), vals.end()); });
+    mean.visit([&](auto vals) { mean_vector.assign(vals.begin(), vals.end()); });
+    inv_std_var.visit([&](auto vals) { inv_std_var_vector.assign(vals.begin(), vals.end()); });
+    input_skip_bias_sum.visit(
+        [&](auto vals) { input_skip_bias_sum_vector.assign(vals.begin(), vals.end()); });
+
+    std::vector<half> gold = {half{1.10596}, half{0.8411}, half{5.24},  half{-1.33},
+                              half{1.0838},  half{1.2012}, half{-4.03}, half{-0.844},
+                              half{1.1385},  half{0.723},  half{0.496}, half{0.169},
+                              half{1.0},     half{1.1984}, half{-3.3},  half{-2.0}};
+
+    std::vector<half> gold_mean{half{3.5625}, half{5.6875}, half{63.0}, half{23.421875}};
+    std::vector<half> gold_inv_std_var{half{0.5298}, half{0.4194}, half{0.1260}, half{0.2067}};
+    std::vector<half> gold_input_skip_bias_sum = {
+        half{3.0},  half{-0.5},      half{3.0},       half{3.0},
+        half{3.0},  half{3.3984375}, half{-2.0},      half{3.0},
+        half{12.0}, half{-10.0},     half{0.0},       half{4.0},
+        half{1.0},  half{5.796875},  half{-4.203125}, half{7.6015625}};
+
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+    EXPECT(migraphx::verify::verify_rms_range(mean_vector, gold_mean));
+    EXPECT(migraphx::verify::verify_rms_range(inv_std_var_vector, gold_inv_std_var));
+    EXPECT(
+        migraphx::verify::verify_rms_range(input_skip_bias_sum_vector, gold_input_skip_bias_sum));
+}