Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
sunjiweiswift committed Aug 29, 2024
1 parent 90662a8 commit a0b94f3
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 54 deletions.
81 changes: 36 additions & 45 deletions include/subgroup/tile/impl/load_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,27 +114,45 @@ tile_load(tile_t& tile, payload_t& payload) {
// static constexpr uint32_t max_load_height_in_elem = trans
// ? load_store_attr::max_trans_load_height_in_elem
// : load_store_attr::max_load_height_in_elem;
static constexpr uint32_t max_trans_load_width_in_elem =
load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype);
static constexpr uint32_t max_load_width_in_elem =
load_store_attr::max_load_width_in_bytes / sizeof(dtype);
// static constexpr uint32_t max_trans_load_width_in_elem =
// load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype);
// static constexpr uint32_t max_load_width_in_elem =
// load_store_attr::max_load_width_in_bytes / sizeof(dtype);

// static constexpr uint32_t max_trans_load_height_in_elem =
// load_store_attr::max_trans_load_height_in_elem;
static constexpr uint32_t max_load_height_in_elem =
load_store_attr::max_load_height_in_elem;

// static constexpr uint32_t max_load_height_in_elem =
// load_store_attr::max_load_height_in_elem;

static constexpr uint32_t elems_per_CL =
load_store_attr::cache_line_size_in_bytes / sizeof(dtype);

static constexpr uint32_t elems_per_reg =
register_bytes_t<arch_tag>::reg_in_bytes / sizeof(dtype);

static constexpr uint32_t ld_blk_size_y_limit =
mem_transpose ? max_trans_load_width_in_elem : max_load_height_in_elem;
static constexpr uint32_t ld_blk_size_y = reg_transpose
? block_size_y
: std::min(ld_blk_size_y_limit, block_size_y);
static constexpr uint32_t max_ld_blk_width_in_elem = trans
? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
: load_store_attr::max_load_width_in_bytes / sizeof(dtype);

static constexpr uint32_t max_ld_blk_height_in_elem = trans
? load_store_attr::max_trans_load_height_in_elem
: load_store_attr::max_load_height_in_elem;

static constexpr uint32_t ld_blk_width = std::min(
(mem_transpose ? block_size_y : block_size_x), max_ld_blk_width_in_elem);
static constexpr uint32_t ld_blk_height = std::min(
(mem_transpose ? block_size_x : block_size_y), max_ld_blk_height_in_elem);

static constexpr uint32_t ld_blk_size_y =
mem_transpose ? ld_blk_width : ld_blk_height;

static constexpr uint32_t ld_blk_size_y_limit = mem_transpose
? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
: load_store_attr::max_load_height_in_elem;
// static constexpr uint32_t ld_blk_size_y = reg_transpose
// ? block_size_y
// : std::min(ld_blk_size_y_limit, block_size_y);

// array len is used to make sure memory load is cache line aligned
// disabled while register or memory transpose
Expand Down Expand Up @@ -198,10 +216,7 @@ tile_load(tile_t& tile, payload_t& payload) {
constexpr uint32_t load_block_elems = block_elems * arr_len;
auto reg_blk = tile.reg.xetla_select<load_block_elems, 1>(
(i * num_block_x + j) * block_elems);
constexpr uint32_t ld_blk_height = (reg_transpose && trans)
? detail::getNextPowerOf2<ld_blk_size_y>()
: ld_blk_size_y;
constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
constexpr uint32_t tmp_size = ld_blk_width * ld_blk_height * arr_len;
xetla_vector<dtype, tmp_size> reg_tmp;
#pragma unroll
for (uint32_t ii = 0; ii < block_size_y / ld_blk_size_y; ++ii) {
Expand All @@ -213,10 +228,8 @@ tile_load(tile_t& tile, payload_t& payload) {
mem_transpose ? offset_x : (offset_y + ii * ld_blk_size_y);
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
native_type_t<load_dtype>,
(trans ? ld_blk_size_y : block_size_x) / scale_factor,
(trans ? block_size_x : ld_blk_size_y),
// block_size_x / scale_factor,
// ld_blk_size_y,
ld_blk_width / scale_factor,
ld_blk_height,
arr_len,
trans,
mem_transform,
Expand Down Expand Up @@ -261,11 +274,6 @@ tile_load(tile_t& tile, payload_t& payload) {
(mem_transpose ? remained_blk_size_y : block_size_x) / scale_factor;
constexpr uint8_t block_height =
mem_transpose ? block_size_x : remained_blk_size_y;
// constexpr uint32_t block_widthx_widthy_arrlen =
// (block_width - 1) | ((block_height - 1) << 8);
// gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
// tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);

reg_blk.xetla_select<load_elems, 1>(remained_start)
.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
native_type_t<load_dtype>,
Expand All @@ -283,15 +291,6 @@ tile_load(tile_t& tile, payload_t& payload) {
payload.surface_pitch,
payload.offset_x + offset_x / scale_factor,
payload.offset_y + offset_y + remained_start_y);

// xetla_tload_global<
// load_dtype,
// (load_elems / scale_factor),
// L1,
// L2,
// trans,
// mem_transform,
// arch_tag>(tdesc);
}
}
}
Expand All @@ -304,24 +303,16 @@ tile_load(tile_t& tile, payload_t& payload) {
(!reg_transpose && (remained_size_y > ld_blk_size_y_limit))
? ld_blk_size_y_limit
: remained_size_y;
// auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
// num_block_y * num_block_x, 0);
// detail::reset_tile_desc_core<
// num_block_x,
// block_size_x,
// remained_ld_blk_size_y,
// scale_factor,
// arr_len,
// mem_transpose>(payload_row);

#pragma unroll
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
int32_t offset_x = j * block_size_x;
// xetla_tdescriptor tdesc = payload_row.row(j);
auto reg_blk = tile.reg.xetla_select<remained_block_elems * arr_len, 1>(
processed_elems + j * remained_block_elems);
constexpr uint32_t ld_blk_height = (reg_transpose && trans)
? detail::getNextPowerOf2<remained_ld_blk_size_y>()
: remained_ld_blk_size_y;
// constexpr uint32_t ld_blk_height = (reg_transpose && trans)
// ? detail::getNextPowerOf2<remained_ld_blk_size_y>()
// : remained_ld_blk_size_y;
constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
xetla_vector<dtype, tmp_size> reg_tmp;
#pragma unroll
Expand Down
18 changes: 9 additions & 9 deletions tests/integration/gemm/fp32/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ TYPED_TEST_P(fp32_gemm_test, esimd) {

REGISTER_TYPED_TEST_SUITE_P(fp32_gemm_test, esimd);
using tests = ::testing::Types<
// Test1,
// Test2,
// Test3,
// Test4,
// Test5,
// Test6,
// Test7,
// Test8,
// Test9,
Test1,
Test2,
Test3,
Test4,
Test5,
Test6,
Test7,
Test8,
Test9,
Test10,
Test11>;
INSTANTIATE_TYPED_TEST_SUITE_P(fp32_gemm_test_suite, fp32_gemm_test, tests);

0 comments on commit a0b94f3

Please sign in to comment.