diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6e6207055347..fd1349c66c3e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -265,6 +265,60 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "OS400")
   set(CMAKE_CXX_ARCHIVE_CREATE " -X64 qc ")
 endif()
 
+if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
+  # Detect ARM SVE in two steps: (1) can the C compiler build SVE intrinsics
+  # when given -march=armv8-a+sve, and (2) does the machine running the
+  # configure step report an SVE vector length via prctl(PR_SVE_GET_VL).
+  # Only when both succeed are the SVE flag and SVE_SUPPORT_DETECTED added to
+  # CMAKE_CXX_FLAGS.
+  include(CheckCSourceCompiles)
+  include(CMakePushCheckState)
+
+  # Scope the probe flag to the check through CMAKE_REQUIRED_FLAGS so the
+  # project's CMAKE_C_FLAGS are never modified (and never need restoring).
+  cmake_push_check_state()
+  string(APPEND CMAKE_REQUIRED_FLAGS " -march=armv8-a+sve")
+  check_c_source_compiles("
+  #if defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
+  #include <arm_sve.h>
+  int main() {
+    svfloat64_t a;
+    a = svdup_n_f64(0);
+    return 0;
+  }
+  #endif
+  " COMPILER_HAS_ARM_SVE)
+  cmake_pop_check_state()
+
+  if(COMPILER_HAS_ARM_SVE)
+    message(STATUS "ARM SVE compiler support detected")
+    set(SOURCE_CODE "
+    #include <sys/prctl.h>
+    int main() {
+      int ret = prctl(PR_SVE_GET_VL);
+      return ret >= 0 ? 0 : 1;
+    }
+    ")
+    file(WRITE "${CMAKE_BINARY_DIR}/check_sve_support.c" "${SOURCE_CODE}")
+    # The probe exits with 0 only when prctl(PR_SVE_GET_VL) is non-negative.
+    try_run(RUN_RESULT COMPILE_RESULT
+      "${CMAKE_BINARY_DIR}/check_sve_support_output"
+      "${CMAKE_BINARY_DIR}/check_sve_support.c"
+    )
+
+    if(RUN_RESULT EQUAL 0)
+      message(STATUS "ARM SVE hardware support detected")
+      string(APPEND CMAKE_CXX_FLAGS " -march=armv8-a+sve -DSVE_SUPPORT_DETECTED")
+    else()
+      message(STATUS "ARM SVE hardware support not detected")
+    endif()
+  else()
+    message(STATUS "ARM SVE compiler support not detected")
+  endif()
+else()
+  message(STATUS "Not an aarch64 architecture")
+endif()
+
 if(USE_NCCL)
   find_package(Nccl REQUIRED)
 endif()
diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc
index dfd80cb68c13..a01bd675f4d4 100644
--- a/src/common/hist_util.cc
+++ b/src/common/hist_util.cc
@@ -1,5 +1,6 @@
 /**
  * Copyright 2017-2023 by XGBoost Contributors
+ * Copyright 2024 FUJITSU LIMITED
  * \file hist_util.cc
  */
 #include "hist_util.h"
@@ -15,6 +16,10 @@
 #include "xgboost/context.h" // for Context
 #include "xgboost/data.h" // for SparsePage, SortedCSCPage
 
+#if defined(SVE_SUPPORT_DETECTED)
+#include <arm_sve.h> // to leverage sve
intrinsics +#endif + #if defined(XGBOOST_MM_PREFETCH_PRESENT) #include #define PREFETCH_READ_T0(addr) _mm_prefetch(reinterpret_cast(addr), _MM_HINT_T0) @@ -252,13 +257,52 @@ void RowsWiseBuildHistKernel(Span gpair, Span(gr_index_local[j]) + (kAnyMissing ? 0 : offsets[j])); - auto hist_local = hist_data + idx_bin; - *(hist_local) += pgh_t[0]; - *(hist_local + 1) += pgh_t[1]; - } + #if defined(SVE_SUPPORT_DETECTED) + svfloat64_t pgh_t0_vec = svdup_n_f64(pgh_t[0]); + svfloat64_t pgh_t1_vec = svdup_n_f64(pgh_t[1]); + + for (size_t j = 0; j < row_size; j+=svcntw()) { + svbool_t pg32 = svwhilelt_b32(j, row_size); + svbool_t pg64 = svwhilelt_b64(j, row_size); + svuint32_t gr_index_vec = + svld1ub_u32(pg32, reinterpret_cast (&gr_index_local[j])); + svuint32_t offsets_vec = svld1(pg32, &offsets[j]); + svuint32_t idx_bin_vec; + if (kAnyMissing) { + idx_bin_vec = svmul_n_u32_x(pg32, gr_index_vec, two); + } else { + svuint32_t temp = svadd_u32_m(pg32, gr_index_vec, offsets_vec); + idx_bin_vec = svmul_n_u32_x(pg32, temp, two); + } + svuint64_t idx_bin_vec0_0 = svunpklo_u64(idx_bin_vec); + svuint64_t idx_bin_vec0_1 = svunpkhi_u64(idx_bin_vec); + svuint64_t idx_bin_vec1_0 = svadd_n_u64_m(pg64, idx_bin_vec0_0, 1); + svuint64_t idx_bin_vec1_1 = svadd_n_u64_m(pg64, idx_bin_vec0_1, 1); + + svfloat64_t hist0_vec0 = svld1_gather_index(pg64, hist_data, idx_bin_vec0_0); + svfloat64_t hist0_vec1 = svld1_gather_index(pg64, hist_data, idx_bin_vec0_1); + svfloat64_t hist1_vec0 = svld1_gather_index(pg64, hist_data, idx_bin_vec1_0); + svfloat64_t hist1_vec1 = svld1_gather_index(pg64, hist_data, idx_bin_vec1_1); + + hist0_vec0 = svadd_f64_m(pg64, hist0_vec0, pgh_t0_vec); + hist0_vec1 = svadd_f64_m(pg64, hist0_vec1, pgh_t0_vec); + hist1_vec0 = svadd_f64_m(pg64, hist1_vec0, pgh_t1_vec); + hist1_vec1 = svadd_f64_m(pg64, hist1_vec1, pgh_t1_vec); + + svst1_scatter_index(pg64, hist_data, idx_bin_vec0_0, hist0_vec0); + svst1_scatter_index(pg64, hist_data, idx_bin_vec0_1, hist0_vec1); + 
svst1_scatter_index(pg64, hist_data, idx_bin_vec1_0, hist1_vec0); + svst1_scatter_index(pg64, hist_data, idx_bin_vec1_1, hist1_vec1); + } + #else + for (size_t j = 0; j < row_size; ++j) { + const uint32_t idx_bin = + two * (static_cast(gr_index_local[j]) + (kAnyMissing ? 0 : offsets[j])); + auto hist_local = hist_data + idx_bin; + *(hist_local) += pgh_t[0]; + *(hist_local + 1) += pgh_t[1]; + } + #endif } }