Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
fix per-channel fp16 load
Browse files Browse the repository at this point in the history
  • Loading branch information
sunjiweiswift committed Aug 29, 2024
1 parent 4b56d54 commit 82e9ce2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 33 deletions.
21 changes: 3 additions & 18 deletions include/common/core/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -476,24 +476,9 @@ __XETLA_API xetla_vector<T, N> xetla_load_global(
Y);
return ret.xetla_format<T>();
} else if constexpr (BlockWidth * sizeof(T) < sizeof(uint32_t)) {
constexpr auto scale_factor = sizeof(uint32_t) / sizeof(T);
xetla_vector<uint32_t, N> ret = xetla_load_global<
uint32_t,
BlockWidth,
BlockHeight,
NBlocks,
Transposed,
Transformed,
L1H,
L2H>(
reinterpret_cast<const uint32_t*>(Ptr),
SurfaceWidth,
SurfaceHeight,
SurfacePitch,
X / scale_factor,
Y);
return ret.xetla_format<T>().xetla_select<N, scale_factor>(
X % scale_factor);
xetla_vector<uint32_t, BlockHeight> byte_offsets =
xetla_vector_gen<uint32_t, BlockHeight>(0, SurfacePitch);
return xetla_load_global<T, N, BlockWidth, L1H, L2H>(Ptr, byte_offsets);
} else {
return __ESIMD_ENS::lsc_load_2d<
T,
Expand Down
21 changes: 6 additions & 15 deletions include/subgroup/tile/impl/load_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,10 @@ tile_load(tile_t& tile, payload_t& payload) {
reg_tmp
.xetla_format<
native_type_t<load_dtype>,
block_size_x / scale_factor,
ld_blk_width / scale_factor,
ld_blk_height>()
.xetla_select<
block_size_x / scale_factor,
ld_blk_width / scale_factor,
1,
ld_blk_size_y,
1>(0, 0);
Expand Down Expand Up @@ -297,9 +297,9 @@ tile_load(tile_t& tile, payload_t& payload) {
// xetla_tdescriptor tdesc = payload_row.row(j);
auto reg_blk = tile.reg.xetla_select<remained_block_elems * arr_len, 1>(
processed_elems + j * remained_block_elems);
// constexpr uint32_t ld_blk_height = (reg_transpose && trans)
// ? detail::getNextPowerOf2<remained_ld_blk_size_y>()
// : remained_ld_blk_size_y;
constexpr uint32_t ld_blk_height = (reg_transpose && trans)
? detail::getNextPowerOf2<remained_ld_blk_size_y>()
: remained_ld_blk_size_y;
constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
xetla_vector<dtype, tmp_size> reg_tmp;
#pragma unroll
Expand All @@ -311,7 +311,7 @@ tile_load(tile_t& tile, payload_t& payload) {
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
native_type_t<load_dtype>,
block_size_x / scale_factor,
ld_blk_height,
remained_ld_blk_size_y,
arr_len,
trans,
mem_transform,
Expand All @@ -325,15 +325,6 @@ tile_load(tile_t& tile, payload_t& payload) {
payload.offset_x + offset_x / scale_factor,
payload.offset_y + num_block_y * block_size_y +
ii * remained_ld_blk_size_y);
// xetla_tload_global<
// load_dtype,
// (ld_blk_height * block_size_x * arr_len / scale_factor),
// L1,
// L2,
// trans,
// mem_transform,
// arch_tag>(tdesc);

if constexpr (reg_transpose && trans) {
reg_blk.xetla_select<load_elems, 1>(ii * load_elems)
.xetla_format<native_type_t<load_dtype>>() =
Expand Down

0 comments on commit 82e9ce2

Please sign in to comment.