Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

[BesTLA] Sync compiler's compatibility #279

Merged
merged 13 commits into from
Jun 4, 2024
16 changes: 11 additions & 5 deletions bestla/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
cmake_minimum_required(VERSION 3.12)

project(bestla LANGUAGES CXX VERSION 0.1.0)

include(cmake/FindSIMD.cmake)

file(GLOB headers ${PROJECT_NAME}/*.h ${PROJECT_NAME}/*.hpp)
file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp)

option(BTLA_ENABLE_OPENMP "Compile OpenMP thread pool if OMP can be found" OFF)
option(BTLA_SYCL "Enable SYCL device support" OFF)
Expand All @@ -22,11 +23,16 @@ option(BTLA_UT_NOASAN "Disable sanitize" OFF)
option(BTLA_UT_BENCHMARK "Benchmark ON may take a long time to finish all tests" OFF)
option(BTLA_UT_OPENMP "Use OpenMP for UT tests" OFF)




include(FetchContent)
FetchContent_Declare(
xbyak
GIT_REPOSITORY https://github.com/herumi/xbyak.git
GIT_TAG v7.06
)
FetchContent_MakeAvailable(xbyak)

add_library(${PROJECT_NAME} INTERFACE)
target_link_libraries(${PROJECT_NAME} INTERFACE xbyak)
add_library(neural_speed::${PROJECT_NAME} ALIAS ${PROJECT_NAME})
target_include_directories(
${PROJECT_NAME} INTERFACE
Expand Down
8 changes: 4 additions & 4 deletions bestla/bestla/bestla_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -496,10 +496,10 @@ class CpuRuntime {

inline void adjustPE(const BTLA_ISA isa, const float PE_) {
// printf("Adjust:%d,%f\n",int(isa),PE_);
PE[int(isa)] = PE[int(isa)] * PE_ * 0.7 + PE[int(isa)] * 0.3;
PE[int(isa)] = PE[int(isa)] * PE_ * 0.7f + PE[int(isa)] * 0.3f;
}

size_t mL2Cache, mL1Cache, mL2Cache_P = 0, mL1Cache_P = 0, mL2Cache_E = 0, mL1Cache_E = 0;
size_t mL2Cache = 0, mL1Cache = 0, mL2Cache_P = 0, mL1Cache_P = 0, mL2Cache_E = 0, mL1Cache_E = 0;
int P_core_num = 0, E_core_num = 0;
bool mHybrid = false;

Expand Down Expand Up @@ -530,8 +530,8 @@ class CpuRuntime {
}
}
}
float PE[int(BTLA_ISA::ISA_COUNT)];
int maxThreads;
float PE[int(BTLA_ISA::ISA_COUNT)] = {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f};
int maxThreads = 0;
};
} // namespace device
} // namespace bestla
6 changes: 4 additions & 2 deletions bestla/bestla/bestla_prologue_b.h
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,8 @@ class WeightKBlockNInteger {
for (int i = thdp.loc[1]; i < thdp.loc[1] + thdp.size[1]; i++) {
if (i < rawnk_scale) {
for (int j = 0; j < N; j++) {
stor->template SPtr<utils::f8>()[i * stor->mNPad + j] = scales[j * rawnk_scale + i];
stor->template SPtr<utils::f8>()[i * stor->mNPad + j] =
static_cast<int>(scales[j * rawnk_scale + i]);
}
} else {
std::memset(stor->template SPtr<utils::f8>() + i * stor->mNPad, 0, stor->mNPad * sizeof(utils::f8));
Expand Down Expand Up @@ -771,14 +772,15 @@ class WeightKBlockNInteger {
if (wptr->mDType == BTLA_DTYPE::S4_CLIP) {
if (wptr->SDtype() == BTLA_DTYPE::DQ8_BNB) {
auto internal_n_offset = n_offset + i;
int dq_offset = static_cast<int>(wptr->mCorrection.mDQCorrectionBuf.mBufSize / sizeof(float) - 1);
kernel::wrapper::DecompressDQKBlockS4Fp<_T, _GemmCore_T::PACK_ROW>::template forward<ISA_T,
BTLA_DTYPE::S4_CLIP>(
wptr->template WPtr<utils::int4x2>() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 +
i * KPad / 2,
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize,
wptr->template SPtr<uint8_t>(), wptr->template DQPtr<float>(), k_offset / _GemmCore_T::PACK_ROW,
internal_n_offset, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, wptr->mN, wptr->mDqBlockSize,
wptr->mCorrection.mDQCorrectionBuf.mBufSize / sizeof(float) - 1, tmpcache, cachesize);
dq_offset, tmpcache, cachesize);
} else {
auto sptr = wptr->template SPtr<void>();
kernel::wrapper::DecompressKBlockS4Fp<_GemmCore_T::PACK_ROW, _GemmCore_T::NTILE, _T>::template forward<ISA_T>(
Expand Down
45 changes: 10 additions & 35 deletions bestla/bestla/bestla_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,41 +60,16 @@

// As long as the compiler supports the ISA, we will enable it.
// Only the ISA you use in your project will be compiled.
#ifdef __GNUC__
#define CompileAVX512F() (__GNUC__ >= 6)
#define CompileAVX512VNNI() (__GNUC__ >= 9)
#define CompileAVX2() (__GNUC__ >= 5)
#define CompileAVXVNNI() (__GNUC__ >= 11)
#define CompileAMX() (__GNUC__ >= 11)
#define CompileBF16() (__GNUC__ >= 11)
#define CompileFP16() (__GNUC__ >= 13)
#define CompileAMXBF16() (CompileAMX())
#define CompileAMXINT8() (CompileAMX())
#endif

#if defined(_MSC_VER) && !defined(__INTEL_LLVM_COMPILER)
#define CompileAVX512F() _MSC_VER && (_MSC_VER >= 1911)
#define CompileAVX512VNNI() _MSC_VER && (_MSC_VER >= 1930) // TODO(Yu) check the minimum version
#define CompileAVX2() _MSC_VER && (_MSC_VER >= 1900)
#define CompileAVXVNNI() _MSC_VER && (_MSC_VER >= 1930) // TODO(Yu) check the minimum version
#define CompileAMX() _MSC_VER && (_MSC_VER >= 1930) // TODO(Yu) check the minimum version
#define CompileBF16() _MSC_VER && (_MSC_VER >= 1938) // TODO(Yu) check the minimum version
#define CompileFP16() _MSC_VER && (_MSC_VER >= 1938) // TODO(Yu) check the minimum version
#define CompileAMXBF16() (CompileAMX())
#define CompileAMXINT8() (CompileAMX())
#endif

#if defined(_MSC_VER) && defined(__INTEL_LLVM_COMPILER)
#define CompileAVX512F() defined(__AVX512F__)
#define CompileAVX512VNNI() defined(__AVX512VNNI__)
#define CompileAVX2() defined(__AVX2__) && defined(__F16C__) && defined(__FMA__)
#define CompileAVXVNNI() defined(__AVXVNNI__)
#define CompileAMX() defined(__AMX_TILE__)
#define CompileBF16() defined(__AVX512BF16__)
#define CompileFP16() defined(__AVX512FP16__)
#define CompileAMXBF16() (CompileAMX())
#define CompileAMXINT8() (CompileAMX())
#endif
#define CompileAVX512F() BTLA_AVX512_FOUND
#define CompileAVX512VNNI() BTLA_AVX512_VNNI_FOUND
#define CompileAVX2() BTLA_AVX2_FOUND
#define CompileAVXVNNI() BTLA_AVX_VNNI_FOUND
#define CompileBF16() BTLA_AVX512_BF16_FOUND
#define CompileFP16() BTLA_AVX512_FP16_FOUND
#define CompileAMXBF16() BTLA_AMX_BF16_FOUND
#define CompileAMXFP16() BTLA_AMX_FP16_FOUND
#define CompileAMXINT8() BTLA_AMX_INT8_FOUND
#define CompileAMX() BTLA_AMX_BF16_FOUND

// called by launcher, time critical functions
#define TLACALL \
Expand Down
2 changes: 2 additions & 0 deletions bestla/bestla/bestla_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ class LauncherBase {
} else {
gemm(_param, _config);
}
bestla::kernel::wrapper::ZeroReg::forward();
}

protected:
Expand Down Expand Up @@ -709,6 +710,7 @@ class LauncherIntKBlock {
} else {
gemm(_param, _config);
}
bestla::kernel::wrapper::ZeroReg::forward();
}

protected:
Expand Down
33 changes: 18 additions & 15 deletions bestla/bestla/kernel_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ namespace avx2 {
#pragma clang attribute push(__attribute__((target("avx,avx2,fma"))), apply_to = function)
#endif

// Zero the upper 128 bits of every YMM register (_mm256_zeroupper).
// NOTE(review): presumably invoked when leaving AVX2 kernels to avoid
// AVX/SSE transition penalties on legacy-SSE callers — confirm with callers.
static inline void zero_reg() { _mm256_zeroupper(); }

static inline __m256i unpack_4bits(void* srcptr, __m256i mask) {
auto raw_data = _mm_loadu_si128(reinterpret_cast<__m128i*>(srcptr));
auto ymm0 = _mm256_cvtepu8_epi16(raw_data);
Expand Down Expand Up @@ -539,7 +541,8 @@ static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr,
vout_y = _mm256_sub_epi8(vout_y, vbias);
_mm256_storeu_si256((__m256i*)(dstptr + i), vout_y);
} else {
ref::decompress_kblock_s4_s8<1, 1>(srcptr + i / 2, nullptr, dstptr + i, 0, 0, 0, 0, 1, elesize - i, nullptr, 0);
ref::decompress_kblock_s4_s8<1, 1>(srcptr + i / 2, nullptr, dstptr + i, 0, 0, 0, 0, 1,
static_cast<int>(elesize - i), nullptr, 0);
}
}
return BTLA_CODE::Success;
Expand Down Expand Up @@ -732,15 +735,15 @@ static inline BTLA_CODE decompress_s2_s8(utils::bit2x4* bit2ptr, int8_t* dstptr,
size_t tmpsize) {
int constexpr VBits = 256;
int constexpr VElt = VBits / 8;
int i = 0;
size_t i = 0;
uint64_t mask0 = 0x0303030303030303;
auto vmask0 = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0);
auto vbias = _mm256_set1_epi8(2);
auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0);
auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2,
13, 9, 5, 1, 12, 8, 4, 0);
auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0);
int elt_pad = utils::padto_le(unpack_elt, VElt);
size_t elt_pad = utils::padto_le(unpack_elt, VElt);
for (; i < elt_pad; i += VElt) {
auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y);
vout = _mm256_sub_epi8(vout, vbias);
Expand Down Expand Up @@ -981,7 +984,7 @@ static inline BTLA_CODE decompress_s3_s8(utils::bit2x4* bit2ptr, utils::bit1x8*
size_t unpack_elt, int8_t* tmp, size_t tmpsize) {
int constexpr VBits = 256;
int constexpr VElt = VBits / 8;
int i = 0;
size_t i = 0;
uint64_t mask0 = 0x0303030303030303;
auto vmask0 = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0);
auto vbias = _mm256_set1_epi8(4);
Expand All @@ -994,7 +997,7 @@ static inline BTLA_CODE decompress_s3_s8(utils::bit2x4* bit2ptr, utils::bit1x8*
const __m256i bit1Mask = _mm256_set1_epi32(0x0F);
const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0);
const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2));
int elt_pad = utils::padto_le(unpack_elt, VElt);
size_t elt_pad = utils::padto_le(unpack_elt, VElt);
for (; i < elt_pad; i += VElt) {
auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y);
auto vb1 = unpack_1bits(bit1ptr + i / 8, bit1Shift_1, bit1Mask, bit1Shift_2, highMask);
Expand Down Expand Up @@ -1213,15 +1216,15 @@ static inline BTLA_CODE decompress_s1_s8(utils::bit1x8* bit1ptr, int8_t* dstptr,
size_t tmpsize) {
int constexpr VBits = 256;
int constexpr VElt = VBits / 8;
int i = 0;
size_t i = 0;
int constexpr FullRange = 1 << (1 - 1);
auto vbias = _mm256_set1_epi8(FullRange);

const __m256i highMask = _mm256_set1_epi8(0x04);
const __m256i bit1Mask = _mm256_set1_epi32(0x0F);
const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0);
const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2));
int elt_pad = utils::padto_le(unpack_elt, VElt);
size_t elt_pad = utils::padto_le(unpack_elt, VElt);
for (; i < elt_pad; i += VElt) {
auto vb1 = unpack_1bits(bit1ptr + i / 8, bit1Shift_1, bit1Mask, bit1Shift_2, highMask);
vb1 = _mm256_srli_epi32(vb1, 2);
Expand Down Expand Up @@ -1460,7 +1463,7 @@ static inline BTLA_CODE decompress_s5_s8(utils::bit4x2* bit4ptr, utils::bit1x8*
size_t unpack_elt, int8_t* tmp, size_t tmpsize) {
int constexpr VBits = 256;
int constexpr VElt = VBits / 8;
int i = 0;
size_t i = 0;
int constexpr FullRange = 1 << (5 - 1);
uint32_t mask = 0x0f0f0f0f;
auto vmask = _mm256_set1_epi32(*reinterpret_cast<int*>(&mask));
Expand All @@ -1470,7 +1473,7 @@ static inline BTLA_CODE decompress_s5_s8(utils::bit4x2* bit4ptr, utils::bit1x8*
const __m256i bit1Mask = _mm256_set1_epi32(0x0F);
const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0);
const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2));
int elt_pad = utils::padto_le(unpack_elt, VElt);
size_t elt_pad = utils::padto_le(unpack_elt, VElt);
for (; i < elt_pad; i += VElt) {
auto vout = unpack_4bits(bit4ptr + i / 2, vmask);
auto vb1 = unpack_1bits(bit1ptr + i / 8, bit1Shift_1, bit1Mask, bit1Shift_2, highMask);
Expand Down Expand Up @@ -1760,7 +1763,7 @@ static inline BTLA_CODE decompress_s7_s8(utils::bit4x2* bit4ptr, utils::bit2x4*
int8_t* dstptr, size_t unpack_elt, int8_t* tmp, size_t tmpsize) {
int constexpr VBits = 256;
int constexpr VElt = VBits / 8;
int i = 0;
size_t i = 0;
int constexpr FullRange = 1 << (7 - 1);
uint32_t mask = 0x0f0f0f0f;
auto vmask = _mm256_set1_epi32(*reinterpret_cast<int*>(&mask));
Expand All @@ -1777,7 +1780,7 @@ static inline BTLA_CODE decompress_s7_s8(utils::bit4x2* bit4ptr, utils::bit2x4*
const __m256i bit1Mask = _mm256_set1_epi32(0x0F);
const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0);
const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2));
int elt_pad = utils::padto_le(unpack_elt, VElt);
size_t elt_pad = utils::padto_le(unpack_elt, VElt);
for (; i < elt_pad; i += VElt) {
auto vout = unpack_4bits(bit4ptr + i / 2, vmask);
auto vb1 = unpack_1bits(bit1ptr + i / 8, bit1Shift_1, bit1Mask, bit1Shift_2, highMask);
Expand Down Expand Up @@ -2035,7 +2038,7 @@ static inline BTLA_CODE decompress_s6_s8(utils::bit4x2* bit4ptr, utils::bit2x4*
size_t unpack_elt, int8_t* tmp, size_t tmpsize) {
int constexpr VBits = 256;
int constexpr VElt = VBits / 8;
int i = 0;
size_t i = 0;
int constexpr FullRange = 1 << (6 - 1);
uint32_t mask = 0x0f0f0f0f;
auto vmask = _mm256_set1_epi32(*reinterpret_cast<int*>(&mask));
Expand All @@ -2047,7 +2050,7 @@ static inline BTLA_CODE decompress_s6_s8(utils::bit4x2* bit4ptr, utils::bit2x4*
auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2,
13, 9, 5, 1, 12, 8, 4, 0);
auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0);
int elt_pad = utils::padto_le(unpack_elt, VElt);
size_t elt_pad = utils::padto_le(unpack_elt, VElt);
for (; i < elt_pad; i += VElt) {
auto vout = unpack_4bits(bit4ptr + i / 2, vmask);
auto vb1 = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y);
Expand Down Expand Up @@ -3474,8 +3477,8 @@ inline __m256 exp_ps_0_1(const __m256 x) {
static const auto log2e = _mm256_set1_ps(v_log2e);
static const auto half = _mm256_set1_ps(.5f);

static const auto upper_bound = _mm256_set1_ps(88.722838); // log(max_positive_float)
static const auto lower_bound = _mm256_set1_ps(-87.336549); // log(min_positive_float)
static const auto upper_bound = _mm256_set1_ps(88.722838f); // log(max_positive_float)
static const auto lower_bound = _mm256_set1_ps(-87.336549f); // log(min_positive_float)
__m256 x1 = _mm256_min_ps(x, upper_bound);
x1 = _mm256_max_ps(x1, lower_bound);

Expand Down
25 changes: 13 additions & 12 deletions bestla/bestla/kernel_avx512f.h
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,7 @@ static inline BTLA_CODE quantize_f32_sign_int_rowblock_sym_auto(const float* src
NVal = sum > 0.f ? -FullValue : FullValue;
}
NVal = NVal << (8 - NBits);
tmp_abs[iv] = NVal;
tmp_abs[iv] = static_cast<float>(NVal);
}
auto vmag = _mm512_loadu_ps(tmp_abs);
vscale = _mm512_div_ps(vabsval, vmag);
Expand Down Expand Up @@ -2463,7 +2463,8 @@ static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr,
vout_y = _mm512_sub_epi8(vout_y, vbias);
_mm512_storeu_si512((__m512i*)(dstptr + i), vout_y);
} else {
ref::decompress_kblock_s4_s8<1, 1>(srcptr + i / 2, nullptr, dstptr + i, 0, 0, 0, 0, 1, elesize - i, nullptr, 0);
ref::decompress_kblock_s4_s8<1, 1>(srcptr + i / 2, nullptr, dstptr + i, 0, 0, 0, 0, 1,
static_cast<int>(elesize - i), nullptr, 0);
}
}
return BTLA_CODE::Success;
Expand Down Expand Up @@ -2649,7 +2650,7 @@ static inline BTLA_CODE decompress_s2_s8(utils::bit2x4* bit2ptr, int8_t* dstptr,
size_t tmpsize) {
int constexpr VBits = 512;
int constexpr VElt = VBits / 8;
int i = 0;
size_t i = 0;
uint64_t mask0 = 0x0303030303030303;
auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0);
auto vbias = _mm512_set1_epi8(2);
Expand All @@ -2658,7 +2659,7 @@ static inline BTLA_CODE decompress_s2_s8(utils::bit2x4* bit2ptr, int8_t* dstptr,
13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0,
15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0);
int elt_pad = utils::padto_le(unpack_elt, VElt);
size_t elt_pad = utils::padto_le(unpack_elt, VElt);
for (; i < elt_pad; i += VElt) {
auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y);
vout = _mm512_sub_epi8(vout, vbias);
Expand Down Expand Up @@ -2871,7 +2872,7 @@ static inline BTLA_CODE decompress_s3_s8(utils::bit2x4* bit2ptr, utils::bit1x8*
size_t unpack_elt, int8_t* tmp, size_t tmpsize) {
int constexpr VBits = 512;
int constexpr VElt = VBits / 8;
int i = 0;
size_t i = 0;
uint64_t mask0 = 0x0303030303030303;
auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0);
auto vbias = _mm512_set1_epi8(4);
Expand All @@ -2883,7 +2884,7 @@ static inline BTLA_CODE decompress_s3_s8(utils::bit2x4* bit2ptr, utils::bit1x8*

auto zmm_0x04 = _mm512_set1_epi8(0x04);
auto zmm_0x00 = _mm512_set1_epi8(0x00);
int elt_pad = utils::padto_le(unpack_elt, VElt);
size_t elt_pad = utils::padto_le(unpack_elt, VElt);
for (; i < elt_pad; i += VElt) {
auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y);
auto vb1 = unpack_1bits(bit1ptr + i / 8, zmm_0x00, zmm_0x04);
Expand Down Expand Up @@ -3131,13 +3132,13 @@ static inline BTLA_CODE decompress_s1_s8(utils::bit1x8* bit1ptr, int8_t* dstptr,
size_t tmpsize) {
int constexpr VBits = 512;
int constexpr VElt = VBits / 8;
int i = 0;
size_t i = 0;
int constexpr FullRange = 1 << (1 - 1);
auto vbias = _mm512_set1_epi8(FullRange);

auto zmm_0x04 = _mm512_set1_epi8(0x04);
auto zmm_0x00 = _mm512_set1_epi8(0x00);
int elt_pad = utils::padto_le(unpack_elt, VElt);
size_t elt_pad = utils::padto_le(unpack_elt, VElt);
for (; i < elt_pad; i += VElt) {
auto vb1 = unpack_1bits(bit1ptr + i / 8, zmm_0x00, zmm_0x04);
vb1 = _mm512_srli_epi32(vb1, 2);
Expand Down Expand Up @@ -3351,15 +3352,15 @@ static inline BTLA_CODE decompress_s5_s8(utils::bit4x2* bit4ptr, utils::bit1x8*
size_t unpack_elt, int8_t* tmp, size_t tmpsize) {
int constexpr VBits = 512;
int constexpr VElt = VBits / 8;
int i = 0;
size_t i = 0;
int constexpr FullRange = 1 << (5 - 1);
uint32_t mask = 0x0f0f0f0f;
auto vmask = _mm512_set1_epi32(*reinterpret_cast<int*>(&mask));
auto vbias = _mm512_set1_epi8(FullRange);

auto zmm_0x04 = _mm512_set1_epi8(0x04);
auto zmm_0x00 = _mm512_set1_epi8(0x00);
int elt_pad = utils::padto_le(unpack_elt, VElt);
size_t elt_pad = utils::padto_le(unpack_elt, VElt);
for (; i < elt_pad; i += VElt) {
auto vout = unpack_4bits(bit4ptr + i / 2, vmask);
auto vb1 = unpack_1bits(bit1ptr + i / 8, zmm_0x00, zmm_0x04);
Expand Down Expand Up @@ -3617,7 +3618,7 @@ static inline BTLA_CODE decompress_s7_s8(utils::bit4x2* bit4ptr, utils::bit2x4*

auto zmm_0x04 = _mm512_set1_epi8(0x04);
auto zmm_0x00 = _mm512_set1_epi8(0x00);
int elt_pad = utils::padto_le(unpack_elt, VElt);
size_t elt_pad = utils::padto_le(unpack_elt, VElt);
for (; i < elt_pad; i += VElt) {
auto vout = unpack_4bits(bit4ptr + i / 2, vmask);
auto vb1 = unpack_1bits(bit1ptr + i / 8, zmm_0x00, zmm_0x04);
Expand Down Expand Up @@ -3929,7 +3930,7 @@ static inline BTLA_CODE decompress_s6_s8(utils::bit4x2* bit4ptr, utils::bit2x4*
13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0,
15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0);
int elt_pad = utils::padto_le(unpack_elt, VElt);
size_t elt_pad = utils::padto_le(unpack_elt, VElt);
for (; i < elt_pad; i += VElt) {
auto vout = unpack_4bits(bit4ptr + i / 2, vmask);
auto vb1 = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y);
Expand Down
Loading
Loading