From b43108b32586d85d3dca9cb5f5ce76307268c2fb Mon Sep 17 00:00:00 2001 From: divya2108 Date: Thu, 3 Oct 2024 14:24:00 +0530 Subject: [PATCH] Optimised code design and handled ci test failures --- cmake/CheckSVEsupport.cmake | 2 +- src/common/hist_util.cc | 70 ++++++++++++++++++++++--------------- 2 files changed, 43 insertions(+), 29 deletions(-) diff --git a/cmake/CheckSVEsupport.cmake b/cmake/CheckSVEsupport.cmake index f7d92ad5f679..3abc19e6b1b2 100644 --- a/cmake/CheckSVEsupport.cmake +++ b/cmake/CheckSVEsupport.cmake @@ -1,7 +1,7 @@ function(check_xgboost_sve_support) if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") include(CheckCSourceCompiles) - + # Save the original C_FLAGS to restore later set(ORIGINAL_C_FLAGS "${CMAKE_C_FLAGS}") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8-a+sve") diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index 6ee073523f9d..1986a7e4277f 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -18,6 +18,7 @@ #ifdef __linux__ #include +#define PR_SVE_GET_VL 51 #endif #ifdef XGBOOST_SVE_COMPILER_SUPPORT @@ -201,7 +202,8 @@ inline void UpdateHistogramWithSVE(size_t row_size, const BinIdxType *gr_index_l for (size_t j = 0; j < row_size; j += svcntw()) { // Create a predicate (mask) for 32-bit & 64-bit elements, active only for valid elements svbool_t pg32 = svwhilelt_b32(j, row_size); - svbool_t pg64 = svwhilelt_b64(j, row_size); + svbool_t pg64_lower = svwhilelt_b64(j, row_size); + svbool_t pg64_upper = svwhilelt_b64(j+svcntd(), row_size); // Load the gradient index values and offsets for the current chunk of the row svuint32_t gr_index_vec = @@ -216,42 +218,53 @@ inline void UpdateHistogramWithSVE(size_t row_size, const BinIdxType *gr_index_l idx_bin_vec = svmul_n_u32_x(pg32, temp, two); } - // Unpack the 32-bit index binary vector into 64-bit vectors from lower and upper half - // respectively + // Unpack 32-bit index binary vector into 64-bit vectors from lower & upper half respectively svuint64_t idx_bin_vec0_0 = svunpklo_u64(idx_bin_vec); svuint64_t idx_bin_vec0_1 = svunpkhi_u64(idx_bin_vec); // Increment the indices by 1 for hessian. - svuint64_t idx_bin_vec1_0 = svadd_n_u64_m(pg64, idx_bin_vec0_0, 1); - svuint64_t idx_bin_vec1_1 = svadd_n_u64_m(pg64, idx_bin_vec0_1, 1); + svuint64_t idx_bin_vec1_0 = svadd_n_u64_m(pg64_lower, idx_bin_vec0_0, 1); + svuint64_t idx_bin_vec1_1 = svadd_n_u64_m(pg64_upper, idx_bin_vec0_1, 1); // Gather the histogram data corresponding to the computed indices - svfloat64_t hist0_vec0 = svld1_gather_index(pg64, hist_data, idx_bin_vec0_0); - svfloat64_t hist0_vec1 = svld1_gather_index(pg64, hist_data, idx_bin_vec0_1); - svfloat64_t hist1_vec0 = svld1_gather_index(pg64, hist_data, idx_bin_vec1_0); - svfloat64_t hist1_vec1 = svld1_gather_index(pg64, hist_data, idx_bin_vec1_1); + svfloat64_t hist0_vec0 = svld1_gather_index(pg64_lower, hist_data, idx_bin_vec0_0); + svfloat64_t hist0_vec1 = svld1_gather_index(pg64_upper, hist_data, idx_bin_vec0_1); + svfloat64_t hist1_vec0 = svld1_gather_index(pg64_lower, hist_data, idx_bin_vec1_0); + svfloat64_t hist1_vec1 = svld1_gather_index(pg64_upper, hist_data, idx_bin_vec1_1); // Accumulate the gradient and hessian values into the histogram - hist0_vec0 = svadd_f64_m(pg64, hist0_vec0, grad); - hist0_vec1 = svadd_f64_m(pg64, hist0_vec1, grad); - hist1_vec0 = svadd_f64_m(pg64, hist1_vec0, hess); - hist1_vec1 = svadd_f64_m(pg64, hist1_vec1, hess); + hist0_vec0 = svadd_f64_m(pg64_lower, hist0_vec0, grad); + hist0_vec1 = svadd_f64_m(pg64_upper, hist0_vec1, grad); + hist1_vec0 = svadd_f64_m(pg64_lower, hist1_vec0, hess); + hist1_vec1 = svadd_f64_m(pg64_upper, hist1_vec1, hess); // Store the updated histogram data back into memory - svst1_scatter_index(pg64, hist_data, idx_bin_vec0_0, hist0_vec0); - svst1_scatter_index(pg64, hist_data, idx_bin_vec0_1, hist0_vec1); - svst1_scatter_index(pg64, hist_data, idx_bin_vec1_0, hist1_vec0); - svst1_scatter_index(pg64, hist_data, idx_bin_vec1_1, hist1_vec1); + svst1_scatter_index(pg64_lower, hist_data, idx_bin_vec0_0, hist0_vec0); + svst1_scatter_index(pg64_upper, hist_data, idx_bin_vec0_1, hist0_vec1); + svst1_scatter_index(pg64_lower, hist_data, idx_bin_vec1_0, hist1_vec0); + svst1_scatter_index(pg64_upper, hist_data, idx_bin_vec1_1, hist1_vec1); } } #endif -// Returns true if SVE ISA is available on the current CPU -bool check_sve_hw_support() { - int ret = prctl(PR_SVE_GET_VL); - return ret >= 0 ? 1 : 0; +// Returns true if SVE ISA is available on the current CPU (with caching) +#ifdef __linux__ +int check_sve_hw_support() { + static int cached_sve_support = -1; + if (cached_sve_support == -1) { + int ret = prctl(PR_SVE_GET_VL); + if (ret == -1) { + cached_sve_support = 0; + } else { + cached_sve_support = 1; + } + } + return cached_sve_support; } +static int sve_enabled = check_sve_hw_support(); +#endif + template void RowsWiseBuildHistKernel(Span gpair, Span row_indices, const GHistIndexMatrix &gmat, GHistRow hist) { @@ -289,7 +302,6 @@ void RowsWiseBuildHistKernel(Span gpair, Span gpair, Span gpair, Span