Skip to content

Commit

Permalink
TL/UCP: make local copy non-blocking (nb) in allgather
Browse files (browse the repository at this point in the history)
  • Loading branch information
Sergei-Lebedev committed Nov 6, 2023
1 parent 8a7b494 commit b7fd093
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 8 deletions.
50 changes: 43 additions & 7 deletions src/components/tl/ucp/allgather/allgather_ring.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task)
ucc_rank_t sendto, recvfrom, sblock, rblock;
int step;
void *buf;
ucc_status_t status;

if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
return;
Expand All @@ -69,7 +70,12 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task)
}
}
ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task));
task->super.status = UCC_OK;
status = ucc_ee_executor_task_test(task->allgather_ring.etask);
if (status == UCC_INPROGRESS) {
return;
}
ucc_ee_executor_task_finalize(task->allgather_ring.etask);
task->super.status = status;
out:
UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_ring_done", 0);
}
Expand All @@ -88,22 +94,49 @@ ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *coll_task)
ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num;
size_t data_size = (count / tsize) * ucc_dt_size(dt);
ucc_status_t status;
ucc_rank_t block;
ucc_rank_t sendto, recvfrom, sblock, rblock;
ucc_ee_executor_t *exec;
ucc_ee_executor_task_args_t eargs;
void *buf;

UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_ring_start", 0);
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);

sendto = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize);
recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize);
sblock = task->allgather_ring.get_send_block(&task->subset, trank, tsize, 0);
rblock = task->allgather_ring.get_recv_block(&task->subset, trank, tsize, 0);
if (!UCC_IS_INPLACE(TASK_ARGS(task))) {
block = task->allgather_ring.get_send_block(&task->subset, trank, tsize,
0);
status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * block),
sbuf, data_size, rmem, smem);
if (ucc_unlikely(UCC_OK != status)) {
status = ucc_coll_task_get_executor(&task->super, &exec);
if (ucc_unlikely(status != UCC_OK)) {
return status;
}

eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY;
eargs.copy.src = sbuf;
eargs.copy.dst = PTR_OFFSET(rbuf, data_size * sblock);
eargs.copy.len = data_size;

status = ucc_ee_executor_task_post(exec, &eargs,
&task->allgather_ring.etask);
if (ucc_unlikely(status != UCC_OK)) {
return status;
}
buf = sbuf;
} else {
buf = PTR_OFFSET(rbuf, data_size * sblock);
}

UCPCHECK_GOTO(ucc_tl_ucp_send_nb(buf, data_size, smem, sendto, team, task),
task, out);
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(PTR_OFFSET(rbuf, rblock * data_size),
data_size, rmem, recvfrom, team, task),
task, out);

return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);

out:
return status;
}

ucc_status_t ucc_tl_ucp_allgather_ring_init_common(ucc_tl_ucp_task_t *task)
Expand All @@ -128,6 +161,9 @@ ucc_status_t ucc_tl_ucp_allgather_ring_init_common(ucc_tl_ucp_task_t *task)
task->allgather_ring.get_recv_block = ucc_tl_ucp_allgather_ring_get_recv_block;
task->super.post = ucc_tl_ucp_allgather_ring_start;
task->super.progress = ucc_tl_ucp_allgather_ring_progress;
if (!UCC_IS_INPLACE(TASK_ARGS(task))) {
task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
}

return UCC_OK;
}
Expand Down
1 change: 1 addition & 0 deletions src/components/tl/ucp/tl_ucp_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ typedef struct ucc_tl_ucp_task {
ucc_rank_t trank,
ucc_rank_t tsize,
int step);
ucc_ee_executor_task_t *etask;
} allgather_ring;
struct {
ucc_rank_t dist;
Expand Down
11 changes: 10 additions & 1 deletion src/components/tl/ucp/tl_ucp_service_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,14 @@ ucc_status_t ucc_tl_ucp_service_allgather(ucc_base_team_t *team, void *sbuf,
task->n_polls = npolls;
task->super.progress = ucc_tl_ucp_allgather_ring_progress;
task->super.finalize = ucc_tl_ucp_coll_finalize;
if (in_place) {
task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
}

status = ucc_tl_ucp_service_coll_start_executor(&task->super);
if (status != UCC_OK) {
goto free_task;
}

status = ucc_tl_ucp_allgather_ring_start(&task->super);
if (status != UCC_OK) {
Expand All @@ -187,7 +195,8 @@ ucc_status_t ucc_tl_ucp_service_allgather(ucc_base_team_t *team, void *sbuf,
*task_p = &task->super;
return status;
finalize_coll:
ucc_tl_ucp_coll_finalize(*task_p);
ucc_tl_ucp_coll_finalize(&task->super);
ucc_tl_ucp_service_coll_stop_executor(&task->super);
free_task:
ucc_tl_ucp_put_task(task);
return status;
Expand Down

0 comments on commit b7fd093

Please sign in to comment.