Skip to content

Commit

Permalink
Add unpack_int4 kernel (#3528)
Browse files Browse the repository at this point in the history
  • Loading branch information
music-dino authored Oct 18, 2024
1 parent 9e8c23e commit ea8ec1c
Show file tree
Hide file tree
Showing 6 changed files with 273 additions and 68 deletions.
90 changes: 90 additions & 0 deletions src/targets/gpu/jit/unpack_int4.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "migraphx/instruction.hpp"
#include "migraphx/instruction_ref.hpp"
#include <migraphx/reduce_dims.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

using namespace migraphx::gpu::gen; // NOLINT

static const char* const unpack_int4_kernel = R"__migraphx__(
#include <migraphx/kernels/unpack_int4.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
unpack_int4<${axis}>(xs...);
});
}
}
} // namespace migraphx
)__migraphx__";

struct unpack_int4_compiler : compiler<unpack_int4_compiler>
{
std::vector<std::string> names() const { return {"unpack_int4"}; }

operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = reduce_dims(normalize_permutation(options.inputs));
options.kernel_name = "unpack_int4_kernel";
options.set_launch_params(v, compute_global_for(ctx, inputs.front().elements()));

auto src =
interpolate_string(unpack_int4_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(options.inputs.size(), "void * private_p")},
{"args", enum_params(options.inputs.size(), "private_p")},
{"axis", std::to_string(v.at("axis").to<int>())}});
return compile_hip_code_object(src, options);
}

compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return compile_op(ctx, to_shapes(ins->inputs()), op.to_value());
}
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
57 changes: 57 additions & 0 deletions src/targets/gpu/kernels/include/migraphx/kernels/unpack_int4.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_UNPACK_INT4_HPP
#define MIGRAPHX_GUARD_KERNELS_UNPACK_INT4_HPP

#include "migraphx/kernels/types.hpp"
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/tensor_view.hpp>

namespace migraphx {

template <int Axis, class Output, class Input>
__device__ void unpack_int4(Output output, Input input)
{
const auto input_shape = input.get_shape();

make_index().global_stride(input_shape.elements(), [&](auto i) {
auto idx = input_shape.multi(i);
idx[Axis] *= 2;
const auto input_val = input[i];

// unpack_int4 op's normalize_compute_shape will ensure that Input::type is either uint8_t
// or int8_t
if constexpr(is_unsigned<typename Input::type>{})
output[idx] = input_val & 0xfu;
else
// NOLINTNEXTLINE (hicpp-signed-bitwise)
output[idx] = static_cast<int8_t>(static_cast<uint8_t>(input_val) << 4) >> 4;

idx[Axis] += 1;
output[idx] = input_val >> 4;
});
}

} // namespace migraphx
#endif
21 changes: 0 additions & 21 deletions src/targets/gpu/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ struct miopen_apply
add_reshape_lazy_op();
add_group_query_attention_op();
add_scan_slice_op();
add_unpack_int4_op();
}

void copy_params() const
Expand Down Expand Up @@ -584,26 +583,6 @@ struct miopen_apply
ins, mod->insert_instruction(ins, ins->get_operator(), inputs));
});
}

void add_unpack_int4_op()
{
apply_map.emplace("unpack_int4", [=](instruction_ref ins) {
auto inputs = ins->inputs();
auto output = insert_allocation(ins, ins->get_shape());
std::vector<instruction_ref> cpu_inputs;
auto gpu_inputs = ins->inputs();
std::transform(
gpu_inputs.begin(), gpu_inputs.end(), std::back_inserter(cpu_inputs), [&](auto in) {
return mod->insert_instruction(ins, make_op("hip::copy_from_gpu"), in);
});
cpu_inputs.front() =
mod->insert_instruction(ins, make_op("hip::sync_stream"), cpu_inputs);
auto cpu_out = mod->insert_instruction(ins, ins->get_operator(), cpu_inputs);
auto gpu_out =
mod->insert_instruction(ins, make_op("hip::copy_to_gpu"), cpu_out, output);
return mod->replace_instruction(ins, gpu_out);
});
}
};

void lowering::apply(module_pass_manager& mpm) const
Expand Down
20 changes: 20 additions & 0 deletions test/ref/unpack_int4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,23 @@ TEST_CASE(unpack_int4_nchw)
0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}

TEST_CASE(unpack_int4_signed)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int8_type, {2, 2}};
auto l0 = mm->add_literal(migraphx::literal{
s,
{0b1000'0111 /*-8'7*/, 0b0001'1111 /*1'-1*/, 0b1110'1001 /*-2'-7*/, 0b0101'0111 /*5'7*/}});
mm->add_instruction(migraphx::make_op("unpack_int4"), l0);

p.compile(migraphx::make_target("ref"));

auto result = p.eval({}).back();
std::vector<int8_t> results_vector(s.elements());
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });

std::vector<int8_t> gold{7, -8, -1, 1, -7, -2, 7, 5};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
105 changes: 58 additions & 47 deletions test/verify/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,56 +67,67 @@ int main(int argc, const char* argv[])
{
run_verify rv;
rv.add_validation_for("gpu", &validate_gpu);
rv.disable_test_for("cpu", {
"test_if_lp", "test_if_param", "test_if_literal", "test_select_module_add",
"test_select_module_reduce", "test_select_module_conv", "test_split_single_dyn_dim",
"test_resize_dyn", "test_instancenorm_large_3d<migraphx::shape::float_type>",
"test_instancenorm_large_3d<migraphx::shape::half_type>",
// these tests are disabled due issue of lossy downcast, see issue#2517
rv.disable_test_for(
"cpu",
{"test_if_lp",
"test_if_param",
"test_if_literal",
"test_select_module_add",
"test_select_module_reduce",
"test_select_module_conv",
"test_split_single_dyn_dim",
"test_resize_dyn",
"test_instancenorm_large_3d<migraphx::shape::float_type>",
"test_instancenorm_large_3d<migraphx::shape::half_type>",
// these tests are disabled due issue of lossy downcast, see issue#2517
#if defined(__GNUC__) and !defined(__clang__)
"test_batch_quant_dot_1<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, "
"float>",
"test_quant_dot_3args_4<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, "
"float>",
"test_quant_dot_3args_5<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, "
"float>",
"test_batch_quant_dot_1<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, false>, "
"float>",
"test_quant_dot_3args_4<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, false>, "
"float>",
"test_quant_dot_3args_5<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, false>, "
"float>",
"test_batch_quant_dot_1<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8, false>, "
"float>",
"test_quant_dot_3args_4<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8, false>, "
"float>",
"test_quant_dot_3args_5<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8, false>, "
"float>",
"test_batch_quant_dot_1<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, "
"float>",
"test_quant_dot_3args_4<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, "
"float>",
"test_quant_dot_3args_5<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, "
"float>",
"test_batch_quant_dot_1<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, false>, "
"float>",
"test_quant_dot_3args_4<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, false>, "
"float>",
"test_quant_dot_3args_5<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, false>, "
"float>",
"test_batch_quant_dot_1<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8, false>, "
"float>",
"test_quant_dot_3args_4<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8, false>, "
"float>",
"test_quant_dot_3args_5<migraphx::fp8::float8<migraphx::fp8::f8_type::bf8, false>, "
"float>",
#else
"test_batch_quant_dot_1<migraphx::fp8::fp8e4m3fnuz, float>",
"test_quant_dot_3args_4<migraphx::fp8::fp8e4m3fnuz, float>",
"test_quant_dot_3args_5<migraphx::fp8::fp8e4m3fnuz, float>",
"test_batch_quant_dot_1<migraphx::fp8::fp8e4m3fn, float>",
"test_quant_dot_3args_4<migraphx::fp8::fp8e4m3fn, float>",
"test_quant_dot_3args_5<migraphx::fp8::fp8e4m3fn, float>",
"test_batch_quant_dot_1<migraphx::fp8::fp8e5m2, float>",
"test_quant_dot_3args_4<migraphx::fp8::fp8e5m2, float>",
"test_quant_dot_3args_5<migraphx::fp8::fp8e5m2, float>",
"test_batch_quant_dot_1<migraphx::fp8::fp8e4m3fnuz, float>",
"test_quant_dot_3args_4<migraphx::fp8::fp8e4m3fnuz, float>",
"test_quant_dot_3args_5<migraphx::fp8::fp8e4m3fnuz, float>",
"test_batch_quant_dot_1<migraphx::fp8::fp8e4m3fn, float>",
"test_quant_dot_3args_4<migraphx::fp8::fp8e4m3fn, float>",
"test_quant_dot_3args_5<migraphx::fp8::fp8e4m3fn, float>",
"test_batch_quant_dot_1<migraphx::fp8::fp8e5m2, float>",
"test_quant_dot_3args_4<migraphx::fp8::fp8e5m2, float>",
"test_quant_dot_3args_5<migraphx::fp8::fp8e5m2, float>",
#endif
"test_block_reduce_small<3, migraphx::shape::int8_type>",
"test_block_reduce_small<4, migraphx::shape::int8_type>",
"test_block_reduce_small<8, migraphx::shape::int8_type>",
"test_block_reduce_small<16, migraphx::shape::int8_type>",
"test_block_reduce_small<25, migraphx::shape::int8_type>",
"test_block_reduce_small<32, migraphx::shape::int8_type>",
"test_block_reduce_small<64, migraphx::shape::int8_type>",
"test_block_reduce_small<67, migraphx::shape::int8_type>",
"test_block_reduce_small<128, migraphx::shape::int8_type>",
"test_block_reduce_small<129, migraphx::shape::int8_type>",
// disabled because CPU does eliminate_data_type to float for everything
"test_bitwise_and<migraphx::shape::int32_type>",
"test_bitwise_and<migraphx::shape::uint8_type>",
});
"test_block_reduce_small<3, migraphx::shape::int8_type>",
"test_block_reduce_small<4, migraphx::shape::int8_type>",
"test_block_reduce_small<8, migraphx::shape::int8_type>",
"test_block_reduce_small<16, migraphx::shape::int8_type>",
"test_block_reduce_small<25, migraphx::shape::int8_type>",
"test_block_reduce_small<32, migraphx::shape::int8_type>",
"test_block_reduce_small<64, migraphx::shape::int8_type>",
"test_block_reduce_small<67, migraphx::shape::int8_type>",
"test_block_reduce_small<128, migraphx::shape::int8_type>",
"test_block_reduce_small<129, migraphx::shape::int8_type>",
// disabled because CPU does eliminate_data_type to float for everything
"test_bitwise_and<migraphx::shape::int32_type>",
"test_bitwise_and<migraphx::shape::uint8_type>",

"test_unpack_int4<migraphx::shape::uint8_type>",
"test_unpack_int4<migraphx::shape::int8_type>",
"test_unpack_int4<migraphx::shape::uint8_type, 0>",
"test_unpack_int4<migraphx::shape::int8_type, 0>"});
rv.disable_test_for("gpu",
{
// These passes on MI300 but fails on others, same issue as CPU.
Expand Down
48 changes: 48 additions & 0 deletions test/verify/test_unpack_int4.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

template <migraphx::shape::type_t T, int Axis = -1>
struct test_unpack_int4 : verify_program<test_unpack_int4<T, Axis>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();

auto x = mm->add_parameter("x", migraphx::shape{T, {64, 32}});
mm->add_instruction(migraphx::make_op("unpack_int4", {{"axis", Axis}}), x);

return p;
}
};

template struct test_unpack_int4<migraphx::shape::uint8_type>;
template struct test_unpack_int4<migraphx::shape::int8_type>;
template struct test_unpack_int4<migraphx::shape::uint8_type, 0>;
template struct test_unpack_int4<migraphx::shape::int8_type, 0>;

0 comments on commit ea8ec1c

Please sign in to comment.