diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index 8d958ea9d..1b0d83e74 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -108,21 +108,22 @@ tile_load(tile_t& tile, payload_t& payload) { using load_store_attr = load_store_attr_t; - // static constexpr uint32_t max_load_width_in_elem = trans - // ? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype) - // : load_store_attr::max_load_width_in_bytes / sizeof(dtype); + // static constexpr uint32_t max_load_width_in_elem = trans + // ? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype) + // : load_store_attr::max_load_width_in_bytes / sizeof(dtype); // static constexpr uint32_t max_load_height_in_elem = trans // ? load_store_attr::max_trans_load_height_in_elem // : load_store_attr::max_load_height_in_elem; - static constexpr uint32_t max_trans_load_width_in_elem = - load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype); - static constexpr uint32_t max_load_width_in_elem = - load_store_attr::max_load_width_in_bytes / sizeof(dtype); + // static constexpr uint32_t max_trans_load_width_in_elem = + // load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype); + // static constexpr uint32_t max_load_width_in_elem = + // load_store_attr::max_load_width_in_bytes / sizeof(dtype); // static constexpr uint32_t max_trans_load_height_in_elem = // load_store_attr::max_trans_load_height_in_elem; - static constexpr uint32_t max_load_height_in_elem = - load_store_attr::max_load_height_in_elem; + + // static constexpr uint32_t max_load_height_in_elem = + // load_store_attr::max_load_height_in_elem; static constexpr uint32_t elems_per_CL = load_store_attr::cache_line_size_in_bytes / sizeof(dtype); @@ -130,24 +131,38 @@ tile_load(tile_t& tile, payload_t& payload) { static constexpr uint32_t elems_per_reg = register_bytes_t::reg_in_bytes / sizeof(dtype); - static constexpr uint32_t ld_blk_size_y_limit = - 
mem_transpose ? max_trans_load_width_in_elem : max_load_height_in_elem; - static constexpr uint32_t ld_blk_size_y = reg_transpose - ? block_size_y - : std::min(ld_blk_size_y_limit, block_size_y); + static constexpr uint32_t max_load_width_in_elem = trans + ? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype) + : load_store_attr::max_load_width_in_bytes / sizeof(dtype); + + static constexpr uint32_t max_load_blk_height_in_elem = trans + ? load_store_attr::max_trans_load_height_in_elem + : load_store_attr::max_load_height_in_elem; + + static constexpr uint32_t ld_blk_width = std::min( + (mem_transpose ? block_size_y : block_size_x), max_load_width_in_elem); + + static constexpr uint32_t ld_blk_height = std::min( + (mem_transpose ? block_size_x : block_size_y), + max_load_blk_height_in_elem); + + static constexpr uint32_t ld_blk_size_y = + mem_transpose ? ld_blk_width : ld_blk_height; + + static constexpr uint32_t ld_blk_size_y_limit = mem_transpose + ? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype) + : load_store_attr::max_load_height_in_elem; // array len is used to make sure memory load is cache line aligned // disabled while register or memory transpose static constexpr uint8_t arr_len_candidate = - (reg_transpose || - mem_transpose + ((reg_transpose || mem_transpose) // block elements should be integer // times of register bytes - || ((block_size_y * block_size_x) % elems_per_reg != 0) + || ((block_elems) % elems_per_reg != 0) // tail blocks also need to meet above condition - || - (((tile_size_y % block_size_y) * block_size_x) % elems_per_reg != 0)) || - (block_size_y > ld_blk_size_y_limit) + || (((tile_size_y % block_size_y) * block_size_x) % elems_per_reg != 0)) + // || (block_size_y > load_store_attr::max_load_height_in_elem) ? 1 : (((tile_size_x % elems_per_CL) == 0) ? (((elems_per_CL % block_size_x) == 0) @@ -155,34 +170,24 @@ tile_load(tile_t& tile, payload_t& payload) { : 1) : ((tile_size_x < elems_per_CL) ? 
(tile_size_x / block_size_x) : 1)); - static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1) || - (arr_len_candidate == 2) || (arr_len_candidate == 4); - - static constexpr uint8_t arr_len = - is_valid_arr_len_candidate ? arr_len_candidate : 1; - - static_assert( - reg_transpose || mem_transpose || - (!mem_transpose && - (block_size_x * arr_len) <= max_load_width_in_elem), - "When reg_transpose was disabled, check 2d block width " - "restriction"); - static_assert( - !reg_transpose || - (!mem_transpose && - (block_size_x * arr_len) <= max_trans_load_width_in_elem) || - (mem_transpose && (block_size_y * arr_len) <= max_load_width_in_elem), - "When reg_transpose was enabled, check 2d block width " - "restriction"); - static_assert( - !reg_transpose || - (!mem_transpose && (block_size_y <= max_load_height_in_elem)) || - (mem_transpose && (block_size_x) <= max_load_height_in_elem), - "When reg_transpose was enabled, check 2d block height " - "restriction"); - static_assert( - tile_size_x % (block_size_x * arr_len) == 0, - "tile_size_x should be a multiple of (block_size_x * arr_len)"); + // NBlocks must be {1,2,4} for bytes and words, {1,2} for dwords, 1 for + // qwords. + static constexpr uint8_t arr_len = + ((arr_len_candidate == 1) || + (arr_len_candidate == 2 && sizeof(dtype) <= 4) || + (arr_len_candidate == 4 && sizeof(dtype) <= 2)) + ? 
arr_len_candidate + : 1; + + if constexpr (!trans && !mem_transform) { + static_assert( + (ld_blk_width * arr_len) <= max_load_width_in_elem, + "When Transposed and Transformed are both set to false, BlockWidth * NBlocks must not exceed 64 for bytes, 32 for words, 16 for dwords, and 8 for qwords"); + } else if constexpr (mem_transform) { + static_assert( + (ld_blk_width * arr_len) <= max_load_width_in_elem, + "When Transformed is true then, BlockWidth * NBlocks must not exceed 64 for bytes and 32 for words."); + } static_assert( (reg_transpose && ((block_size_x * sizeof(dtype)) % sizeof(load_dtype) == 0)) || @@ -198,10 +203,7 @@ tile_load(tile_t& tile, payload_t& payload) { constexpr uint32_t load_block_elems = block_elems * arr_len; auto reg_blk = tile.reg.xetla_select( (i * num_block_x + j) * block_elems); - constexpr uint32_t ld_blk_height = (reg_transpose && trans) - ? detail::getNextPowerOf2() - : ld_blk_size_y; - constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len; + constexpr uint32_t tmp_size = ld_blk_width * ld_blk_height * arr_len; xetla_vector reg_tmp; #pragma unroll for (uint32_t ii = 0; ii < block_size_y / ld_blk_size_y; ++ii) { @@ -213,10 +215,8 @@ tile_load(tile_t& tile, payload_t& payload) { mem_transpose ? offset_x : (offset_y + ii * ld_blk_size_y); reg_tmp.xetla_format>() = xetla_load_global< native_type_t, - (trans ? ld_blk_size_y : block_size_x) / scale_factor, - (trans ? block_size_x : ld_blk_size_y), - // block_size_x / scale_factor, - // ld_blk_size_y, + ld_blk_width / scale_factor, + ld_blk_height, arr_len, trans, mem_transform, @@ -261,11 +261,6 @@ tile_load(tile_t& tile, payload_t& payload) { (mem_transpose ? remained_blk_size_y : block_size_x) / scale_factor; constexpr uint8_t block_height = mem_transpose ? 
block_size_x : remained_blk_size_y; - // constexpr uint32_t block_widthx_widthy_arrlen = - // (block_width - 1) | ((block_height - 1) << 8); - // gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen( - // tdesc.xetla_format(), block_widthx_widthy_arrlen); - reg_blk.xetla_select(remained_start) .xetla_format>() = xetla_load_global< native_type_t, @@ -283,15 +278,6 @@ tile_load(tile_t& tile, payload_t& payload) { payload.surface_pitch, payload.offset_x + offset_x / scale_factor, payload.offset_y + offset_y + remained_start_y); - - // xetla_tload_global< - // load_dtype, - // (load_elems / scale_factor), - // L1, - // L2, - // trans, - // mem_transform, - // arch_tag>(tdesc); } } } @@ -304,24 +290,16 @@ tile_load(tile_t& tile, payload_t& payload) { (!reg_transpose && (remained_size_y > ld_blk_size_y_limit)) ? ld_blk_size_y_limit : remained_size_y; - // auto payload_row = payload_2d.xetla_select( - // num_block_y * num_block_x, 0); - // detail::reset_tile_desc_core< - // num_block_x, - // block_size_x, - // remained_ld_blk_size_y, - // scale_factor, - // arr_len, - // mem_transpose>(payload_row); + #pragma unroll for (uint32_t j = 0; j < num_block_x; j += arr_len) { int32_t offset_x = j * block_size_x; // xetla_tdescriptor tdesc = payload_row.row(j); auto reg_blk = tile.reg.xetla_select( processed_elems + j * remained_block_elems); - constexpr uint32_t ld_blk_height = (reg_transpose && trans) - ? detail::getNextPowerOf2() - : remained_ld_blk_size_y; + // constexpr uint32_t ld_blk_height = (reg_transpose && trans) + // ? detail::getNextPowerOf2() + // : remained_ld_blk_size_y; constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len; xetla_vector reg_tmp; #pragma unroll @@ -490,7 +468,8 @@ tile_load(tile_t& tile, payload_t& payload) { /// @brief This function loads data from unaligned-2D memory surface. /// Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into -/// registers. 
Each block will be loaded serially by its corresponding payload. +/// registers. Each block will be loaded serially by its corresponding +/// payload. /// @tparam tile_t Is the tile_t struct contains registers. /// These registers will be the destination of load operation. /// @tparam payload_t Is the mem_payload_t struct describing the memory @@ -614,7 +593,8 @@ tile_load(tile_t& tile, payload_t& payload) { /// @brief This function loads data from unaligned-2D memory surface. /// Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into -/// registers. Each block will be loaded serially by its corresponding payload. +/// registers. Each block will be loaded serially by its corresponding +/// payload. /// @tparam tile_t Is the tile_t struct contains registers. /// These registers will be the destination of load operation. /// @tparam payload_t Is the mem_payload_t struct describing the memory @@ -679,7 +659,8 @@ tile_load(tile_t& tile, payload_t& payload) { /// @brief This function loads data from unaligned-2D memory surface. /// Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into -/// registers. Each block will be loaded serially by its corresponding payload. +/// registers. Each block will be loaded serially by its corresponding +/// payload. /// @tparam tile_t Is the tile_t struct contains registers. /// These registers will be the destination of load operation. /// @tparam payload_t Is the mem_payload_t struct describing the memory @@ -819,8 +800,8 @@ tile_load( } /// @brief Is the data load func from local shared memory to register file, -/// which supports the memory surface is 1d or 2d scenario. And we always assume -/// data in SLM is row major. +/// which supports the memory surface is 1d or 2d scenario. And we always +/// assume data in SLM is row major. /// @tparam tile_t Is the tile_t struct contains registers /// These registers will be the destination of load operation. 
/// @tparam payload_t Is the mem_payload_t struct describing the memory @@ -902,8 +883,8 @@ tile_load(tile_t& tile, payload_t& payload) { } /// @brief Is the data load func from shared local memory to register file, -/// which supports the memory surface is 1d scenario. And the src memory layout -/// is always row major. +/// which supports the memory surface is 1d scenario. And the src memory +/// layout is always row major. /// @tparam tile_t Is the tile_t struct contains registers. /// These registers will be the destination of load operation. /// @tparam payload_t Is the mem_payload_t struct describing the memory diff --git a/tests/integration/gemm/fp32/main.cpp b/tests/integration/gemm/fp32/main.cpp index a05393399..3a9badc9f 100644 --- a/tests/integration/gemm/fp32/main.cpp +++ b/tests/integration/gemm/fp32/main.cpp @@ -34,15 +34,15 @@ TYPED_TEST_P(fp32_gemm_test, esimd) { REGISTER_TYPED_TEST_SUITE_P(fp32_gemm_test, esimd); using tests = ::testing::Types< - // Test1, - // Test2, - // Test3, - // Test4, - // Test5, - // Test6, - // Test7, - // Test8, - // Test9, + Test1, + Test2, + Test3, + Test4, + Test5, + Test6, + Test7, + Test8, + Test9, Test10, Test11>; INSTANTIATE_TYPED_TEST_SUITE_P(fp32_gemm_test_suite, fp32_gemm_test, tests);