Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
sunjiweiswift committed Aug 29, 2024
1 parent 90662a8 commit 4b56d54
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 98 deletions.
159 changes: 70 additions & 89 deletions include/subgroup/tile/impl/load_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,81 +108,86 @@ tile_load(tile_t& tile, payload_t& payload) {

using load_store_attr = load_store_attr_t<msg_type::block_2d, arch_tag>;

// static constexpr uint32_t max_load_width_in_elem = trans
// ? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
// : load_store_attr::max_load_width_in_bytes / sizeof(dtype);
// static constexpr uint32_t max_load_width_in_elem = trans
// ? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
// : load_store_attr::max_load_width_in_bytes / sizeof(dtype);
// static constexpr uint32_t max_load_height_in_elem = trans
// ? load_store_attr::max_trans_load_height_in_elem
// : load_store_attr::max_load_height_in_elem;
static constexpr uint32_t max_trans_load_width_in_elem =
load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype);
static constexpr uint32_t max_load_width_in_elem =
load_store_attr::max_load_width_in_bytes / sizeof(dtype);
// static constexpr uint32_t max_trans_load_width_in_elem =
// load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype);
// static constexpr uint32_t max_load_width_in_elem =
// load_store_attr::max_load_width_in_bytes / sizeof(dtype);

// static constexpr uint32_t max_trans_load_height_in_elem =
// load_store_attr::max_trans_load_height_in_elem;
static constexpr uint32_t max_load_height_in_elem =
load_store_attr::max_load_height_in_elem;

// static constexpr uint32_t max_load_height_in_elem =
// load_store_attr::max_load_height_in_elem;

static constexpr uint32_t elems_per_CL =
load_store_attr::cache_line_size_in_bytes / sizeof(dtype);

static constexpr uint32_t elems_per_reg =
register_bytes_t<arch_tag>::reg_in_bytes / sizeof(dtype);

static constexpr uint32_t ld_blk_size_y_limit =
mem_transpose ? max_trans_load_width_in_elem : max_load_height_in_elem;
static constexpr uint32_t ld_blk_size_y = reg_transpose
? block_size_y
: std::min(ld_blk_size_y_limit, block_size_y);
static constexpr uint32_t max_load_width_in_elem = trans
? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
: load_store_attr::max_load_width_in_bytes / sizeof(dtype);

static constexpr uint32_t max_load_blk_height_in_elem = trans
? load_store_attr::max_trans_load_height_in_elem
: load_store_attr::max_load_height_in_elem;

static constexpr uint32_t ld_blk_width = std::min(
(mem_transpose ? block_size_y : block_size_x), max_load_width_in_elem);

static constexpr uint32_t ld_blk_height = std::min(
(mem_transpose ? block_size_x : block_size_y),
max_load_blk_height_in_elem);

static constexpr uint32_t ld_blk_size_y =
mem_transpose ? ld_blk_width : ld_blk_height;

static constexpr uint32_t ld_blk_size_y_limit = mem_transpose
? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
: load_store_attr::max_load_height_in_elem;

// array len is used to make sure memory load is cache line aligned
// disabled while register or memory transpose
static constexpr uint8_t arr_len_candidate =
(reg_transpose ||
mem_transpose
((reg_transpose || mem_transpose)
// block elements should be integer
// times of register bytes
|| ((block_size_y * block_size_x) % elems_per_reg != 0)
|| ((block_elems) % elems_per_reg != 0)
// tail blocks also need to meet above condition
||
(((tile_size_y % block_size_y) * block_size_x) % elems_per_reg != 0)) ||
(block_size_y > ld_blk_size_y_limit)
|| (((tile_size_y % block_size_y) * block_size_x) % elems_per_reg != 0))
// || (block_size_y > load_store_attr::max_load_height_in_elem)
? 1
: (((tile_size_x % elems_per_CL) == 0)
? (((elems_per_CL % block_size_x) == 0)
? elems_per_CL / block_size_x
: 1)
: ((tile_size_x < elems_per_CL) ? (tile_size_x / block_size_x)
: 1));
static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1) ||
(arr_len_candidate == 2) || (arr_len_candidate == 4);

static constexpr uint8_t arr_len =
is_valid_arr_len_candidate ? arr_len_candidate : 1;

static_assert(
reg_transpose || mem_transpose ||
(!mem_transpose &&
(block_size_x * arr_len) <= max_load_width_in_elem),
"When reg_transpose was disabled, check 2d block width "
"restriction");
static_assert(
!reg_transpose ||
(!mem_transpose &&
(block_size_x * arr_len) <= max_trans_load_width_in_elem) ||
(mem_transpose && (block_size_y * arr_len) <= max_load_width_in_elem),
"When reg_transpose was enabled, check 2d block width "
"restriction");
static_assert(
!reg_transpose ||
(!mem_transpose && (block_size_y <= max_load_height_in_elem)) ||
(mem_transpose && (block_size_x) <= max_load_height_in_elem),
"When reg_transpose was enabled, check 2d block height "
"restriction");
static_assert(
tile_size_x % (block_size_x * arr_len) == 0,
"tile_size_x should be a multiple of (block_size_x * arr_len)");
// NBlocks must be {1,2,4} for bytes and words, {1,2} for dwords, 1 for
// qwords.
static constexpr bool arr_len =
((arr_len_candidate == 1) ||
(arr_len_candidate == 2 && sizeof(dtype) <= 4) ||
(arr_len_candidate == 4 && sizeof(dtype) <= 2))
? arr_len_candidate
: 1;

if constexpr (!trans && !mem_transform) {
static_assert(
(ld_blk_width * arr_len) <= max_load_width_in_elem,
"When Transposed and Transformed are both set to false, BlockWidth * NBlocks must not exceed 64 for bytes, 32 for words, 16 for dwords, and 8 for qwords");
} else if constexpr (mem_transform) {
static_assert(
(ld_blk_width * arr_len) <= max_load_width_in_elem,
"When Transformed is true then, BlockWidth * NBlocks must not exceed 64 for bytes and 32 for words.");
}
static_assert(
(reg_transpose &&
((block_size_x * sizeof(dtype)) % sizeof(load_dtype) == 0)) ||
Expand All @@ -198,10 +203,7 @@ tile_load(tile_t& tile, payload_t& payload) {
constexpr uint32_t load_block_elems = block_elems * arr_len;
auto reg_blk = tile.reg.xetla_select<load_block_elems, 1>(
(i * num_block_x + j) * block_elems);
constexpr uint32_t ld_blk_height = (reg_transpose && trans)
? detail::getNextPowerOf2<ld_blk_size_y>()
: ld_blk_size_y;
constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
constexpr uint32_t tmp_size = ld_blk_width * ld_blk_height * arr_len;
xetla_vector<dtype, tmp_size> reg_tmp;
#pragma unroll
for (uint32_t ii = 0; ii < block_size_y / ld_blk_size_y; ++ii) {
Expand All @@ -213,10 +215,8 @@ tile_load(tile_t& tile, payload_t& payload) {
mem_transpose ? offset_x : (offset_y + ii * ld_blk_size_y);
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
native_type_t<load_dtype>,
(trans ? ld_blk_size_y : block_size_x) / scale_factor,
(trans ? block_size_x : ld_blk_size_y),
// block_size_x / scale_factor,
// ld_blk_size_y,
ld_blk_width / scale_factor,
ld_blk_height,
arr_len,
trans,
mem_transform,
Expand Down Expand Up @@ -261,11 +261,6 @@ tile_load(tile_t& tile, payload_t& payload) {
(mem_transpose ? remained_blk_size_y : block_size_x) / scale_factor;
constexpr uint8_t block_height =
mem_transpose ? block_size_x : remained_blk_size_y;
// constexpr uint32_t block_widthx_widthy_arrlen =
// (block_width - 1) | ((block_height - 1) << 8);
// gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
// tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);

reg_blk.xetla_select<load_elems, 1>(remained_start)
.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
native_type_t<load_dtype>,
Expand All @@ -283,15 +278,6 @@ tile_load(tile_t& tile, payload_t& payload) {
payload.surface_pitch,
payload.offset_x + offset_x / scale_factor,
payload.offset_y + offset_y + remained_start_y);

// xetla_tload_global<
// load_dtype,
// (load_elems / scale_factor),
// L1,
// L2,
// trans,
// mem_transform,
// arch_tag>(tdesc);
}
}
}
Expand All @@ -304,24 +290,16 @@ tile_load(tile_t& tile, payload_t& payload) {
(!reg_transpose && (remained_size_y > ld_blk_size_y_limit))
? ld_blk_size_y_limit
: remained_size_y;
// auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
// num_block_y * num_block_x, 0);
// detail::reset_tile_desc_core<
// num_block_x,
// block_size_x,
// remained_ld_blk_size_y,
// scale_factor,
// arr_len,
// mem_transpose>(payload_row);

#pragma unroll
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
int32_t offset_x = j * block_size_x;
// xetla_tdescriptor tdesc = payload_row.row(j);
auto reg_blk = tile.reg.xetla_select<remained_block_elems * arr_len, 1>(
processed_elems + j * remained_block_elems);
constexpr uint32_t ld_blk_height = (reg_transpose && trans)
? detail::getNextPowerOf2<remained_ld_blk_size_y>()
: remained_ld_blk_size_y;
// constexpr uint32_t ld_blk_height = (reg_transpose && trans)
// ? detail::getNextPowerOf2<remained_ld_blk_size_y>()
// : remained_ld_blk_size_y;
constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
xetla_vector<dtype, tmp_size> reg_tmp;
#pragma unroll
Expand Down Expand Up @@ -490,7 +468,8 @@ tile_load(tile_t& tile, payload_t& payload) {

/// @brief This function loads data from unaligned-2D memory surface.
/// Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into
/// registers. Each block will be loaded serially by its corresponding payload.
/// registers. Each block will be loaded serially by its corresponding
/// payload.
/// @tparam tile_t Is the tile_t struct contains registers.
/// These registers will be the destination of load operation.
/// @tparam payload_t Is the mem_payload_t struct describing the memory
Expand Down Expand Up @@ -614,7 +593,8 @@ tile_load(tile_t& tile, payload_t& payload) {

/// @brief This function loads data from unaligned-2D memory surface.
/// Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into
/// registers. Each block will be loaded serially by its corresponding payload.
/// registers. Each block will be loaded serially by its corresponding
/// payload.
/// @tparam tile_t Is the tile_t struct contains registers.
/// These registers will be the destination of load operation.
/// @tparam payload_t Is the mem_payload_t struct describing the memory
Expand Down Expand Up @@ -679,7 +659,8 @@ tile_load(tile_t& tile, payload_t& payload) {

/// @brief This function loads data from unaligned-2D memory surface.
/// Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into
/// registers. Each block will be loaded serially by its corresponding payload.
/// registers. Each block will be loaded serially by its corresponding
/// payload.
/// @tparam tile_t Is the tile_t struct contains registers.
/// These registers will be the destination of load operation.
/// @tparam payload_t Is the mem_payload_t struct describing the memory
Expand Down Expand Up @@ -819,8 +800,8 @@ tile_load(
}

/// @brief Is the data load func from local shared memory to register file,
/// which supports the memory surface is 1d or 2d scenario. And we always assume
/// data in SLM is row major.
/// which supports the memory surface is 1d or 2d scenario. And we always
/// assume data in SLM is row major.
/// @tparam tile_t Is the tile_t struct contains registers
/// These registers will be the destination of load operation.
/// @tparam payload_t Is the mem_payload_t struct describing the memory
Expand Down Expand Up @@ -902,8 +883,8 @@ tile_load(tile_t& tile, payload_t& payload) {
}

/// @brief Is the data load func from shared local memory to register file,
/// which supports the memory surface is 1d scenario. And the src memory layout
/// is always row major.
/// which supports the memory surface is 1d scenario. And the src memory
/// layout is always row major.
/// @tparam tile_t Is the tile_t struct contains registers.
/// These registers will be the destination of load operation.
/// @tparam payload_t Is the mem_payload_t struct describing the memory
Expand Down
18 changes: 9 additions & 9 deletions tests/integration/gemm/fp32/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ TYPED_TEST_P(fp32_gemm_test, esimd) {

REGISTER_TYPED_TEST_SUITE_P(fp32_gemm_test, esimd);
using tests = ::testing::Types<
// Test1,
// Test2,
// Test3,
// Test4,
// Test5,
// Test6,
// Test7,
// Test8,
// Test9,
Test1,
Test2,
Test3,
Test4,
Test5,
Test6,
Test7,
Test8,
Test9,
Test10,
Test11>;
INSTANTIATE_TYPED_TEST_SUITE_P(fp32_gemm_test_suite, fp32_gemm_test, tests);

0 comments on commit 4b56d54

Please sign in to comment.