diff --git a/src/components/tl/ucp/allgather/allgather.h b/src/components/tl/ucp/allgather/allgather.h index 0e816e8838..ac3592df86 100644 --- a/src/components/tl/ucp/allgather/allgather.h +++ b/src/components/tl/ucp/allgather/allgather.h @@ -66,6 +66,8 @@ void ucc_tl_ucp_allgather_bruck_progress(ucc_coll_task_t *task); ucc_status_t ucc_tl_ucp_allgather_bruck_start(ucc_coll_task_t *task); +ucc_status_t ucc_tl_ucp_allgather_bruck_finalize(ucc_coll_task_t *coll_task); + /* Uses allgather_kn_radix from config */ ucc_status_t ucc_tl_ucp_allgather_knomial_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, diff --git a/src/components/tl/ucp/allgather/allgather_bruck.c b/src/components/tl/ucp/allgather/allgather_bruck.c index 052f12802c..0e1d17f353 100644 --- a/src/components/tl/ucp/allgather/allgather_bruck.c +++ b/src/components/tl/ucp/allgather/allgather_bruck.c @@ -17,27 +17,36 @@ ucc_status_t ucc_tl_ucp_allgather_bruck_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, ucc_coll_task_t **task_h) { - ucc_status_t status = UCC_OK; - ucc_tl_ucp_task_t *task; - ucc_tl_ucp_team_t *ucp_team; - - task = ucc_tl_ucp_init_task(coll_args, team); - ucp_team = TASK_TEAM(task); + ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); + ucc_tl_ucp_task_t *task = ucc_tl_ucp_init_task(coll_args, team); + ucc_status_t status = UCC_OK; + ucc_rank_t trank = UCC_TL_TEAM_RANK(tl_team); + ucc_rank_t tsize = UCC_TL_TEAM_SIZE(tl_team); + ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; + size_t count = TASK_ARGS(task).dst.info.count; + size_t data_size = (count / tsize) * ucc_dt_size(dt); + size_t scratch_size = (tsize - trank) * data_size; if (!ucc_coll_args_is_predefined_dt(&TASK_ARGS(task), UCC_RANK_INVALID)) { tl_error(UCC_TASK_LIB(task), "user defined datatype is not supported"); status = UCC_ERR_NOT_SUPPORTED; goto out; } - printf("ucc_tl_ucp_allgather_bruck_init\n"); - if (UCC_TL_TEAM_SIZE(ucp_team) % 2) { - tl_debug(UCC_TASK_LIB(task), - "odd team size is not supported, switching to ring"); - status = ucc_tl_ucp_allgather_ring_init_common(task); - } else { - task->super.post = ucc_tl_ucp_allgather_bruck_start; - task->super.progress = ucc_tl_ucp_allgather_bruck_progress; + tl_trace(UCC_TASK_LIB(task), "ucc_tl_ucp_allgather_bruck_init"); + + task->super.post = ucc_tl_ucp_allgather_bruck_start; + task->super.progress = ucc_tl_ucp_allgather_bruck_progress; + task->super.finalize = ucc_tl_ucp_allgather_bruck_finalize; + + /* allocate scratch buffer */ + status = ucc_mc_alloc(&task->allgather_bruck.scratch_header, scratch_size, + UCC_MEMORY_TYPE_HOST); + if (ucc_unlikely(status != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), "failed to allocate scratch buffer"); + ucc_tl_ucp_coll_finalize(&task->super); + goto out; } + task->allgather_bruck.scratch_size = scratch_size; out: if (status != UCC_OK) { @@ -49,22 +58,48 @@ ucc_status_t ucc_tl_ucp_allgather_bruck_init(ucc_base_coll_args_t *coll_args, return status; } +ucc_status_t ucc_tl_ucp_allgather_bruck_finalize(ucc_coll_task_t *coll_task) +{ + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_status_t status, global_status; + + tl_trace(UCC_TASK_LIB(task), "ucc_tl_ucp_allgather_bruck_finalize"); + + /* deallocate scratch buffer */ + global_status = ucc_mc_free(task->allgather_bruck.scratch_header); + if (ucc_unlikely(global_status != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), "failed to free scratch buffer memory"); + } + task->allgather_bruck.scratch_size = 0; + + status = ucc_tl_ucp_coll_finalize(&task->super); + if (ucc_unlikely(status != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), + "failed to finalize allgather bruck collective"); + global_status = status; + } + return global_status; +} + /* Inspired by implementation: https://github.com/open-mpi/ompi/blob/main/ompi/mca/coll/base/coll_base_allgather.c */ void ucc_tl_ucp_allgather_bruck_progress(ucc_coll_task_t *coll_task) { - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); - ucc_tl_ucp_team_t *team = TASK_TEAM(task); - ucc_rank_t trank = UCC_TL_TEAM_RANK(team); - ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team); - void *rbuf = TASK_ARGS(task).dst.info.buffer; - ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type; - ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; - size_t count = TASK_ARGS(task).dst.info.count; - size_t data_size = (count / tsize) * ucc_dt_size(dt); - ucc_rank_t recvfrom, sendto; - ucc_status_t status; - size_t blockcount, distance; - void *tmprecv, *tmpsend; + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_rank_t trank = UCC_TL_TEAM_RANK(team); + ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team); + void *rbuf = TASK_ARGS(task).dst.info.buffer; + ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type; + ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; + size_t count = TASK_ARGS(task).dst.info.count; + ucc_mc_buffer_header_t *scratch_header = + task->allgather_bruck.scratch_header; + size_t scratch_size = task->allgather_bruck.scratch_size; + size_t data_size = (count / tsize) * ucc_dt_size(dt); + ucc_rank_t recvfrom, sendto; + ucc_status_t status; + size_t blockcount, distance; + void *tmprecv, *tmpsend; if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { return; @@ -72,9 +107,8 @@ void ucc_tl_ucp_allgather_bruck_progress(ucc_coll_task_t *coll_task) /* On each step doubles distance */ distance = 1 << task->tagged.send_posted; - printf("bruck\n"); - tmpsend = rbuf; - while (distance < (tsize)) { + tmpsend = rbuf; + while (distance < tsize) { recvfrom = (trank + distance) % tsize; sendto = (trank + tsize - distance) % tsize; @@ -99,21 +133,12 @@ void ucc_tl_ucp_allgather_bruck_progress(ucc_coll_task_t *coll_task) if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { return; } + + distance = 1 << task->tagged.send_posted; } /* post processing step */ if (trank != 0) { - ucc_mc_buffer_header_t *scratch_header; - size_t scratch_size = (tsize - trank) * data_size; - /* allocate scratch buffer */ - status = - ucc_mc_alloc(&scratch_header, scratch_size, UCC_MEMORY_TYPE_HOST); - if (ucc_unlikely(status != UCC_OK)) { - tl_error(UCC_TASK_LIB(task), "failed to allocate scratch buffer"); - ucc_tl_ucp_coll_finalize(&task->super); - return; - } - status = ucc_mc_memcpy(scratch_header->addr, rbuf, scratch_size, UCC_MEMORY_TYPE_HOST, UCC_MEMORY_TYPE_HOST); if (ucc_unlikely(status != UCC_OK)) { @@ -142,18 +167,10 @@ void ucc_tl_ucp_allgather_bruck_progress(ucc_coll_task_t *coll_task) ucc_tl_ucp_coll_finalize(&task->super); return; } - - /* deallocate scratch buffer */ - status = ucc_mc_free(scratch_header); - if (ucc_unlikely(status != UCC_OK)) { - tl_error(UCC_TASK_LIB(task), - "failed to free scratch buffer memory"); - ucc_tl_ucp_coll_finalize(&task->super); - return; - } } ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task)); + task->super.status = UCC_OK; out: @@ -180,12 +197,13 @@ ucc_status_t ucc_tl_ucp_allgather_bruck_start(ucc_coll_task_t *coll_task) /* initial step: copy data on non root ranks to the beginning of buffer */ if (!UCC_IS_INPLACE(TASK_ARGS(task))) { - status = ucc_mc_memcpy(rbuf, PTR_OFFSET(sbuf, data_size * trank), - data_size, rmem, smem); + // not inplace: copy chunk from source buff to beginning of receive + status = ucc_mc_memcpy(rbuf, sbuf, data_size, rmem, smem); if (ucc_unlikely(UCC_OK != status)) { return status; } } else if (trank != 0) { + // inplace: copy chunk to the begin status = ucc_mc_memcpy(rbuf, PTR_OFFSET(rbuf, data_size * trank), data_size, rmem, rmem); if (ucc_unlikely(UCC_OK != status)) { diff --git a/src/components/tl/ucp/tl_ucp_coll.h b/src/components/tl/ucp/tl_ucp_coll.h index 6ab2c661dd..fc64ed2d0e 100644 --- a/src/components/tl/ucp/tl_ucp_coll.h +++ b/src/components/tl/ucp/tl_ucp_coll.h @@ -181,6 +181,10 @@ typedef struct ucc_tl_ucp_task { ucc_rank_t tsize, int step); } allgather_ring; + struct { + ucc_mc_buffer_header_t *scratch_header; + size_t scratch_size; + } allgather_bruck; struct { ucc_rank_t dist; uint32_t radix;