diff --git a/src/components/tl/ucp/allgather/allgather_ring.c b/src/components/tl/ucp/allgather/allgather_ring.c index 93d7b95fc4..5a5cd1827c 100644 --- a/src/components/tl/ucp/allgather/allgather_ring.c +++ b/src/components/tl/ucp/allgather/allgather_ring.c @@ -43,6 +43,7 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task) ucc_rank_t sendto, recvfrom, sblock, rblock; int step; void *buf; + ucc_status_t status; if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { return; @@ -69,7 +70,12 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task) } } ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task)); - task->super.status = UCC_OK; + status = ucc_ee_executor_task_test(task->allgather_ring.etask); + if (status == UCC_INPROGRESS) { + return; + } + ucc_ee_executor_task_finalize(task->allgather_ring.etask); + task->super.status = status; out: UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_ring_done", 0); } @@ -88,22 +94,49 @@ ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *coll_task) ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num; size_t data_size = (count / tsize) * ucc_dt_size(dt); ucc_status_t status; - ucc_rank_t block; + ucc_rank_t sendto, recvfrom, sblock, rblock; + ucc_ee_executor_t *exec; + ucc_ee_executor_task_args_t eargs; + void *buf; UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_ring_start", 0); ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); + sendto = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize); + recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize); + sblock = task->allgather_ring.get_send_block(&task->subset, trank, tsize, 0); + rblock = task->allgather_ring.get_recv_block(&task->subset, trank, tsize, 0); if (!UCC_IS_INPLACE(TASK_ARGS(task))) { - block = task->allgather_ring.get_send_block(&task->subset, trank, tsize, - 0); - status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * block), - sbuf, data_size, rmem, smem); - if (ucc_unlikely(UCC_OK != status)) { + status = ucc_coll_task_get_executor(&task->super, &exec); + if (ucc_unlikely(status != UCC_OK)) { return status; } + + eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY; + eargs.copy.src = sbuf; + eargs.copy.dst = PTR_OFFSET(rbuf, data_size * sblock); + eargs.copy.len = data_size; + + status = ucc_ee_executor_task_post(exec, &eargs, + &task->allgather_ring.etask); + if (ucc_unlikely(status != UCC_OK)) { + return status; + } + buf = sbuf; + } else { + buf = PTR_OFFSET(rbuf, data_size * sblock); } + UCPCHECK_GOTO(ucc_tl_ucp_send_nb(buf, data_size, smem, sendto, team, task), + task, out); + UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(PTR_OFFSET(rbuf, rblock * data_size), + data_size, rmem, recvfrom, team, task), + task, out); + return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); + +out: + return status; } ucc_status_t ucc_tl_ucp_allgather_ring_init_common(ucc_tl_ucp_task_t *task) @@ -128,6 +161,9 @@ ucc_status_t ucc_tl_ucp_allgather_ring_init_common(ucc_tl_ucp_task_t *task) task->allgather_ring.get_recv_block = ucc_tl_ucp_allgather_ring_get_recv_block; task->super.post = ucc_tl_ucp_allgather_ring_start; task->super.progress = ucc_tl_ucp_allgather_ring_progress; + if (!UCC_IS_INPLACE(TASK_ARGS(task))) { + task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR; + } return UCC_OK; } diff --git a/src/components/tl/ucp/tl_ucp_coll.h b/src/components/tl/ucp/tl_ucp_coll.h index c2a260b103..97a492a1ab 100644 --- a/src/components/tl/ucp/tl_ucp_coll.h +++ b/src/components/tl/ucp/tl_ucp_coll.h @@ -178,6 +178,7 @@ typedef struct ucc_tl_ucp_task { ucc_rank_t trank, ucc_rank_t tsize, int step); + ucc_ee_executor_task_t *etask; } allgather_ring; struct { ucc_rank_t dist; diff --git a/src/components/tl/ucp/tl_ucp_service_coll.c b/src/components/tl/ucp/tl_ucp_service_coll.c index bf16cf00d7..2406d90e58 100644 --- a/src/components/tl/ucp/tl_ucp_service_coll.c +++ b/src/components/tl/ucp/tl_ucp_service_coll.c @@ -178,6 +178,14 @@ ucc_status_t ucc_tl_ucp_service_allgather(ucc_base_team_t *team, void *sbuf, task->n_polls = npolls; task->super.progress = ucc_tl_ucp_allgather_ring_progress; task->super.finalize = ucc_tl_ucp_coll_finalize; + if (in_place) { + task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR; + } + + status = ucc_tl_ucp_service_coll_start_executor(&task->super); + if (status != UCC_OK) { + goto free_task; + } status = ucc_tl_ucp_allgather_ring_start(&task->super); if (status != UCC_OK) { @@ -187,7 +195,8 @@ ucc_status_t ucc_tl_ucp_service_allgather(ucc_base_team_t *team, void *sbuf, *task_p = &task->super; return status; finalize_coll: - ucc_tl_ucp_coll_finalize(*task_p); + ucc_tl_ucp_coll_finalize(&task->super); + ucc_tl_ucp_service_coll_stop_executor(&task->super); free_task: ucc_tl_ucp_put_task(task); return status;