Skip to content

Commit

Permalink
TL/UCP: Convert sliding window to schedule-based
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarka committed Jun 4, 2024
1 parent b7e9b90 commit 8f3ed31
Show file tree
Hide file tree
Showing 9 changed files with 412 additions and 293 deletions.
2 changes: 2 additions & 0 deletions src/components/tl/ucp/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ sources = \
tl_ucp_ep.c \
tl_ucp_coll.c \
tl_ucp_service_coll.c \
tl_ucp_dpu_offload.h \
tl_ucp_dpu_offload.c \
$(allgather) \
$(allgatherv) \
$(alltoall) \
Expand Down
119 changes: 104 additions & 15 deletions src/components/tl/ucp/allreduce/allreduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
#include "tl_ucp.h"
#include "allreduce.h"
#include "utils/ucc_coll_utils.h"
#include "tl_ucp_dpu_offload.h"
#include "../allgather/allgather.h"

ucc_status_t
ucc_tl_ucp_allreduce_sliding_window_alloc_pipe(ucc_base_team_t *team,
ucc_tl_ucp_task_t *task);

ucc_base_coll_alg_info_t
ucc_tl_ucp_allreduce_algs[UCC_TL_UCP_ALLREDUCE_ALG_LAST + 1] = {
Expand Down Expand Up @@ -61,34 +67,117 @@ ucc_tl_ucp_allreduce_sliding_window_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
{
ucc_tl_ucp_task_t *task;
ucc_status_t status = UCC_OK;
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
ucc_schedule_t *schedule = NULL;
ucc_status_t status = UCC_OK;
ucc_tl_ucp_team_t *tl_team =
ucc_derived_of(team, ucc_tl_ucp_team_t);
size_t allgather_size =
sizeof(ucc_tl_ucp_allreduce_sw_host_allgather_t);
ucc_rank_t size = team->params.size;
ucc_base_coll_args_t bargs = {
.mask = 0,
.args = {
.coll_type = UCC_COLL_TYPE_ALLGATHER,
.mask = 0,
.src.info = {.buffer = NULL,
.count = allgather_size,
.datatype = UCC_DT_UINT8,
.mem_type = UCC_MEMORY_TYPE_HOST},
.dst.info = {.buffer = NULL,
.count = allgather_size * size,
.datatype = UCC_DT_UINT8,
.mem_type = UCC_MEMORY_TYPE_HOST}
}
};
ucc_base_coll_args_t barrier_coll_args = {
.team = team->params.team,
.args.coll_type = UCC_COLL_TYPE_BARRIER,
};
ucc_tl_ucp_allreduce_sw_host_allgather_t *allgather_data;
ucc_tl_ucp_task_t *rdma_task;
ucc_coll_task_t *barrier_task;

ALLREDUCE_TASK_CHECK(coll_args->args, tl_team);
status = ucc_tl_ucp_get_schedule(tl_team, coll_args,
(ucc_tl_ucp_schedule_t **)&schedule);
if (ucc_unlikely(UCC_OK != status)) {
return status;
}

*task_h = &schedule->super;
schedule->super.post = ucc_tl_ucp_allreduce_sliding_window_start;
schedule->super.progress = NULL;
schedule->super.finalize = ucc_tl_ucp_allreduce_sliding_window_finalize;

schedule->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;

task = ucc_tl_ucp_init_task(coll_args, team);
if (ucc_unlikely(!task)) {
if (UCC_OK != status) {
ucc_error("failed to init executor: %s", ucc_status_string(status));
}

rdma_task = ucc_tl_ucp_init_task(coll_args, team);
if (ucc_unlikely(!rdma_task)) {
ucc_error("couldnt allocate task");
return UCC_ERR_NO_MEMORY;
}
*task_h = &task->super;
task->super.post = ucc_tl_ucp_allreduce_sliding_window_start;
task->super.progress = ucc_tl_ucp_allreduce_sliding_window_progress;
task->super.finalize = ucc_tl_ucp_allreduce_sliding_window_finalize;

status = ucc_tl_ucp_allreduce_sliding_window_task_init(coll_args, team, task);
if (ucc_tl_ucp_allreduce_sliding_window_alloc_pipe(team, rdma_task) != UCC_OK) {
ucc_error("failed to alloc pipe: %s", ucc_status_string(status));
goto free_rdma_task;
}

status = ucc_tl_ucp_allreduce_sliding_window_task_init(coll_args, team,
rdma_task);
if (status != UCC_OK) {
ucc_tl_ucp_put_task(task);
ucc_error("failed to init task: %s", ucc_status_string(status));
goto out;
}

task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
allgather_data = rdma_task->allreduce_sliding_window.allgather_data;
bargs.args.src.info.buffer = allgather_data;
bargs.args.dst.info.buffer = PTR_OFFSET(allgather_data, allgather_size);

if (UCC_OK != status) {
ucc_error("failed to init executor: %s", ucc_status_string(status));
rdma_task->super.post = ucc_tl_ucp_allreduce_sliding_window_rdma_task_post;
rdma_task->super.progress = ucc_tl_ucp_allreduce_sliding_window_rdma_progress;
rdma_task->super.finalize = ucc_tl_ucp_allreduce_sliding_window_rdma_task_finalize;

UCC_CHECK_GOTO(ucc_tl_ucp_allgather_ring_init(&bargs, team,
&rdma_task->allreduce_sliding_window.allgather_task),
free_rdma_pipe, status);

status = ucc_tl_ucp_coll_init(&barrier_coll_args, team,
&barrier_task);
if (status < 0) {
tl_error(team->context->lib,
"failure during sliding window barrier init: %s",
ucc_status_string(status));
goto free_allgather_task;
}

UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, &rdma_task->super), out, status);
UCC_CHECK_GOTO(ucc_event_manager_subscribe(&schedule->super,
UCC_EVENT_SCHEDULE_STARTED,
&rdma_task->super,
ucc_task_start_handler),
free_barrier_task, status);

UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, barrier_task), out, status);
UCC_CHECK_GOTO(ucc_event_manager_subscribe(&rdma_task->super,
UCC_EVENT_COMPLETED,
barrier_task,
ucc_task_start_handler),
free_barrier_task, status);

return status;

free_barrier_task:
ucc_tl_ucp_coll_finalize(barrier_task);
free_allgather_task:
ucc_tl_ucp_coll_finalize(rdma_task->allreduce_sliding_window.allgather_task);
free_rdma_pipe:
ucc_tl_ucp_allreduce_sliding_window_free_pipe(&rdma_task->super);
free_rdma_task:
ucc_tl_ucp_allreduce_sliding_window_free_task(&rdma_task->super);
out:
ucc_tl_ucp_put_schedule(schedule);
return status;
}
8 changes: 7 additions & 1 deletion src/components/tl/ucp/allreduce/allreduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,17 @@ void ucc_tl_ucp_allreduce_knomial_progress(ucc_coll_task_t *task);
ucc_status_t
ucc_tl_ucp_allreduce_sliding_window_start(ucc_coll_task_t *coll_task);

void ucc_tl_ucp_allreduce_sliding_window_progress(ucc_coll_task_t *task);
void ucc_tl_ucp_allreduce_sliding_window_rdma_progress(ucc_coll_task_t *task);

ucc_status_t
ucc_tl_ucp_allreduce_sliding_window_finalize(ucc_coll_task_t *task);

ucc_status_t
ucc_tl_ucp_allreduce_sliding_window_rdma_task_finalize(ucc_coll_task_t *coll_task);

ucc_status_t
ucc_tl_ucp_allreduce_sliding_window_rdma_task_post(ucc_coll_task_t *coll_task);

ucc_status_t ucc_tl_ucp_allreduce_knomial_finalize(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_allreduce_sra_knomial_init(ucc_base_coll_args_t *coll_args,
Expand Down
Loading

0 comments on commit 8f3ed31

Please sign in to comment.