diff --git a/src/components/tl/ucp/allgather/allgather_sparbit.c b/src/components/tl/ucp/allgather/allgather_sparbit.c index c4d7fd152e..265e691bd4 100644 --- a/src/components/tl/ucp/allgather/allgather_sparbit.c +++ b/src/components/tl/ucp/allgather/allgather_sparbit.c @@ -41,6 +41,15 @@ ucc_status_t ucc_tl_ucp_allgather_sparbit_init(ucc_base_coll_args_t *coll_args, return status; } +static inline uint32_t highest_power2(uint32_t n) +{ + // if n is a power of two simply return it + if (!(n & (n - 1))) + return n; + // else set only the most significant bit + return 0x80000000 >> (__builtin_clz(n)); // number of leading zeros +} + /* Inspired by implementation: https://github.com/open-mpi/ompi/blob/main/ompi/mca/coll/base/coll_base_allgather.c */ void ucc_tl_ucp_allgather_sparbit_progress(ucc_coll_task_t *coll_task) { @@ -53,28 +62,27 @@ void ucc_tl_ucp_allgather_sparbit_progress(ucc_coll_task_t *coll_task) ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; size_t count = TASK_ARGS(task).dst.info.count; size_t data_size = (count / tsize) * ucc_dt_size(dt); - ucc_rank_t recvfrom, sendto; - size_t distance; - uint32_t last_ignore; - uint32_t ignore_steps; - uint32_t i = task->allgather_sparbit.i; // restore iteration number - int tsize_log, exclusion, data_expected, transfer_count; - void *tmprecv, *tmpsend; - - // here we can't made any progress while transfers from previous step are running, emulation of wait all in asyn manier + uint32_t i = task->allgather_sparbit.i; // restore iteration number + ucc_rank_t recvfrom, sendto; + size_t distance; + uint32_t last_ignore, ignore_steps, data_expected, transfer_count; + uint32_t tsize_log, exclusion; + void *tmprecv, *tmpsend; + + // here we can't made any progress while transfers from previous step are running, emulation of wait all in async manier if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { return; } - tsize_log = ceil(log(tsize) / log(2)); - last_ignore = __builtin_ctz(tsize); + tsize_log = highest_power2(tsize); + last_ignore = __builtin_ctz(tsize); // count trailing zeros ignore_steps = (~((uint32_t)tsize >> last_ignore) | 1) << last_ignore; while (i < tsize_log) { data_expected = task->allgather_sparbit.data_expected; distance = 1 << (tsize_log - 1); - distance >>= i; // restore distance in case of continuation + distance >>= i; // restore distance in case of continuation depending on step recvfrom = (trank + tsize - distance) % tsize; sendto = (trank + distance) % tsize; @@ -100,19 +108,14 @@ void ucc_tl_ucp_allgather_sparbit_progress(ucc_coll_task_t *coll_task) task->allgather_sparbit.data_expected = (data_expected << 1) - exclusion; task->allgather_sparbit.i++; - // wait for completion of all tasks to check if we could make one more step right now or we should yeld task + // check if we could make one more step right now or we should yeld task if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { return; } i = task->allgather_sparbit.i; } - if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { - return; - } - ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task)); - task->super.status = UCC_OK; out: