diff --git a/src/components/tl/ucp/allgather/allgather_bruck.c b/src/components/tl/ucp/allgather/allgather_bruck.c index 053e3f06f0..a8ef9017d7 100644 --- a/src/components/tl/ucp/allgather/allgather_bruck.c +++ b/src/components/tl/ucp/allgather/allgather_bruck.c @@ -22,6 +22,7 @@ ucc_status_t ucc_tl_ucp_allgather_bruck_init(ucc_base_coll_args_t *coll_args, ucc_status_t status = UCC_OK; ucc_rank_t trank = UCC_TL_TEAM_RANK(tl_team); ucc_rank_t tsize = UCC_TL_TEAM_SIZE(tl_team); + ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type; ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; size_t count = TASK_ARGS(task).dst.info.count; size_t data_size = (count / tsize) * ucc_dt_size(dt); @@ -38,8 +39,11 @@ ucc_status_t ucc_tl_ucp_allgather_bruck_init(ucc_base_coll_args_t *coll_args, task->super.progress = ucc_tl_ucp_allgather_bruck_progress; task->super.finalize = ucc_tl_ucp_allgather_bruck_finalize; + /* allocate scratch buffer only on non root rank */ if (trank != 0) { - /* allocate scratch buffer only on non root rank */ + if (UCC_MEMORY_TYPE_HOST != rmem) { + scratch_size = tsize * data_size; + } status = ucc_mc_alloc(&task->allgather_bruck.scratch_header, scratch_size, UCC_MEMORY_TYPE_HOST); if (ucc_unlikely(status != UCC_OK)) { @@ -147,27 +151,60 @@ void ucc_tl_ucp_allgather_bruck_progress(ucc_coll_task_t *coll_task) /* post processing step */ if (trank != 0) { - // copy blocks [0 .. (size - rank - 1)] from rbuf to shift buffer - status = ucc_mc_memcpy(scratch_header->addr, rbuf, scratch_size, - UCC_MEMORY_TYPE_HOST, rmem); - if (ucc_unlikely(status != UCC_OK)) { - tl_error(UCC_TASK_LIB(task), - "failed to copy data to scratch buffer"); - ucc_tl_ucp_coll_finalize(&task->super); - return; - } - // move blocks [(size - rank) .. size] from rbuf to beginning of rbuf - // TODO: rewrite to cycle to get rid of overlap - memmove(rbuf, PTR_OFFSET(rbuf, scratch_size), trank * data_size); - // copy blocks from shift buffer starting at block [rank] in rbuf. - status = ucc_mc_memcpy(PTR_OFFSET(rbuf, trank * data_size), - scratch_header->addr, scratch_size, rmem, - UCC_MEMORY_TYPE_HOST); - if (ucc_unlikely(status != UCC_OK)) { - tl_error(UCC_TASK_LIB(task), - "failed to copy data from scratch to rbuff buffer"); - ucc_tl_ucp_coll_finalize(&task->super); - return; + if (UCC_MEMORY_TYPE_HOST == rmem) { + // copy blocks [0 .. (size - rank - 1)] from rbuf to shift buffer + status = ucc_mc_memcpy(scratch_header->addr, rbuf, scratch_size, + UCC_MEMORY_TYPE_HOST, rmem); + if (ucc_unlikely(status != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), + "failed to copy data to scratch buffer"); + ucc_tl_ucp_coll_finalize(&task->super); + return; + } + // move blocks [(size - rank) .. size] from rbuf to beginning of rbuf + // TODO: rewrite to cycle to get rid of overlap + memmove(rbuf, PTR_OFFSET(rbuf, scratch_size), trank * data_size); + // copy blocks from shift buffer starting at block [rank] in rbuf. + status = ucc_mc_memcpy(PTR_OFFSET(rbuf, trank * data_size), + scratch_header->addr, scratch_size, rmem, + UCC_MEMORY_TYPE_HOST); + if (ucc_unlikely(status != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), + "failed to copy data from scratch to rbuff buffer"); + ucc_tl_ucp_coll_finalize(&task->super); + return; + } + } else { + /* In case of non host memory we perform two copy to host buffer and then back to device, 3 memcopy in total */ + /* TODO: replace with generic kernel to do bruck post step in sinle launch on device */ + status = ucc_mc_memcpy( + PTR_OFFSET(scratch_header->addr, trank * data_size), rbuf, + (tsize - trank) * data_size, UCC_MEMORY_TYPE_HOST, rmem); + if (ucc_unlikely(status != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), + "failed to copy first data part to scratch buffer"); + ucc_tl_ucp_coll_finalize(&task->super); + return; + } + status = + ucc_mc_memcpy(scratch_header->addr, + PTR_OFFSET(rbuf, (tsize - trank) * data_size), + trank * data_size, UCC_MEMORY_TYPE_HOST, rmem); + if (ucc_unlikely(status != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), + "failed to copy second data part to scratch buffer"); + ucc_tl_ucp_coll_finalize(&task->super); + return; + } + status = + ucc_mc_memcpy(rbuf, scratch_header->addr, tsize * data_size, + rmem, UCC_MEMORY_TYPE_HOST); + if (ucc_unlikely(status != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), + "failed to copy from scratch buffer to dst"); + ucc_tl_ucp_coll_finalize(&task->super); + return; + } } }