Skip to content

Commit

Permalink
TL/UCP: add allreduce dbt
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev committed Jan 16, 2024
1 parent 27a4c9f commit b3c260c
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 8 deletions.
3 changes: 2 additions & 1 deletion src/components/tl/ucp/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ allreduce = \
allreduce/allreduce.h \
allreduce/allreduce.c \
allreduce/allreduce_knomial.c \
allreduce/allreduce_sra_knomial.c
allreduce/allreduce_sra_knomial.c \
allreduce/allreduce_dbt.c

barrier = \
barrier/barrier.h \
Expand Down
5 changes: 5 additions & 0 deletions src/components/tl/ucp/allreduce/allreduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ ucc_base_coll_alg_info_t
.name = "sra_knomial",
.desc = "recursive knomial scatter-reduce followed by knomial "
"allgather (optimized for BW)"},
[UCC_TL_UCP_ALLREDUCE_ALG_DBT] =
{.id = UCC_TL_UCP_ALLREDUCE_ALG_SRA_KNOMIAL,
.name = "dbt",
.desc = "alreduce over double binary tree where a leaf in one tree "
"will be intermediate in other (optimized for BW)"},
[UCC_TL_UCP_ALLREDUCE_ALG_LAST] = {
.id = 0, .name = NULL, .desc = NULL}};

Expand Down
17 changes: 13 additions & 4 deletions src/components/tl/ucp/allreduce/allreduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
enum {
UCC_TL_UCP_ALLREDUCE_ALG_KNOMIAL,
UCC_TL_UCP_ALLREDUCE_ALG_SRA_KNOMIAL,
UCC_TL_UCP_ALLREDUCE_ALG_DBT,
UCC_TL_UCP_ALLREDUCE_ALG_LAST
};

Expand All @@ -36,8 +37,8 @@ ucc_status_t ucc_tl_ucp_allreduce_init(ucc_tl_ucp_task_t *task);
CHECK_SAME_MEMTYPE((_args), (_team));

ucc_status_t ucc_tl_ucp_allreduce_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * team,
ucc_coll_task_t ** task_h);
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

ucc_status_t ucc_tl_ucp_allreduce_knomial_init_common(ucc_tl_ucp_task_t *task);

Expand All @@ -48,13 +49,21 @@ void ucc_tl_ucp_allreduce_knomial_progress(ucc_coll_task_t *task);
ucc_status_t ucc_tl_ucp_allreduce_knomial_finalize(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_allreduce_sra_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * team,
ucc_coll_task_t ** task_h);
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

ucc_status_t ucc_tl_ucp_allreduce_sra_knomial_start(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_allreduce_sra_knomial_progress(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_allreduce_dbt_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

ucc_status_t ucc_tl_ucp_allreduce_dbt_start(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_allreduce_dbt_progress(ucc_coll_task_t *task);

static inline int ucc_tl_ucp_allreduce_alg_from_str(const char *str)
{
int i;
Expand Down
94 changes: 94 additions & 0 deletions src/components/tl/ucp/allreduce/allreduce_dbt.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/**
* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include "config.h"
#include "tl_ucp.h"
#include "allreduce.h"
#include "../reduce/reduce.h"
#include "../bcast/bcast.h"

ucc_status_t ucc_tl_ucp_allreduce_dbt_start(ucc_coll_task_t *coll_task)
{
ucc_schedule_t *schedule = ucc_derived_of(coll_task, ucc_schedule_t);
ucc_coll_args_t *args = &schedule->super.bargs.args;
ucc_coll_task_t *reduce_task, *bcast_task;

reduce_task = schedule->tasks[0];
reduce_task->bargs.args.src.info.buffer = args->src.info.buffer;
reduce_task->bargs.args.dst.info.buffer = args->dst.info.buffer;
reduce_task->bargs.args.src.info.count = args->src.info.count;
reduce_task->bargs.args.dst.info.count = args->dst.info.count;

bcast_task = schedule->tasks[1];
bcast_task->bargs.args.src.info.buffer = args->dst.info.buffer;
bcast_task->bargs.args.src.info.count = args->dst.info.count;

UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allreduce_dbt_start", 0);
return ucc_schedule_start(coll_task);
}

ucc_status_t ucc_tl_ucp_allreduce_dbt_finalize(ucc_coll_task_t *coll_task)
{
ucc_schedule_t *schedule = ucc_derived_of(coll_task, ucc_schedule_t);
ucc_status_t status;

UCC_TL_UCP_PROFILE_REQUEST_EVENT(schedule, "ucp_allreduce_dbt_done", 0);
status = ucc_schedule_finalize(coll_task);
ucc_tl_ucp_put_schedule(schedule);
return status;
}

ucc_status_t ucc_tl_ucp_allreduce_dbt_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
ucc_base_coll_args_t args = *coll_args;
ucc_schedule_t *schedule;
ucc_coll_task_t *reduce_task, *bcast_task;
ucc_status_t status;

if (UCC_IS_INPLACE(args.args)) {
return UCC_ERR_NOT_SUPPORTED;
}

status = ucc_tl_ucp_get_schedule(tl_team, coll_args,
(ucc_tl_ucp_schedule_t **)&schedule);
if (ucc_unlikely(UCC_OK != status)) {
return status;
}

args.args.root = 0;
UCC_CHECK_GOTO(ucc_tl_ucp_reduce_dbt_init(&args, team, &reduce_task),
out, status);
UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, reduce_task),
out, status);
UCC_CHECK_GOTO(ucc_event_manager_subscribe(&schedule->super,
UCC_EVENT_SCHEDULE_STARTED,
reduce_task,
ucc_task_start_handler),
out, status);

UCC_CHECK_GOTO(ucc_tl_ucp_bcast_dbt_init(&args, team, &bcast_task),
out, status);
UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, bcast_task),
out, status);
UCC_CHECK_GOTO(ucc_event_manager_subscribe(reduce_task, UCC_EVENT_COMPLETED,
bcast_task,
ucc_task_start_handler),
out, status);

schedule->super.post = ucc_tl_ucp_allreduce_dbt_start;
schedule->super.progress = NULL;
schedule->super.finalize = ucc_tl_ucp_allreduce_dbt_finalize;
*task_h = &schedule->super;

return UCC_OK;

out:
ucc_tl_ucp_put_schedule(schedule);
return status;
}
4 changes: 2 additions & 2 deletions src/components/tl/ucp/bcast/bcast_sag_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ ucc_tl_ucp_bcast_sag_knomial_finalize(ucc_coll_task_t *coll_task)

ucc_status_t
ucc_tl_ucp_bcast_sag_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
size_t count = coll_args->args.src.info.count;
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/ucp/reduce/reduce_dbt.c
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ void ucc_tl_ucp_reduce_dbt_progress(ucc_coll_task_t *coll_task)
for (i = 0; i < 2; i++) {
if (is_root && rank == trees[i].root) {
UCPCHECK_GOTO(ucc_mc_memcpy(PTR_OFFSET(args->dst.info.buffer,
i * counts[i] * ucc_dt_size(dt)),
i * counts[i - 1] * ucc_dt_size(dt)),
rbuf[i], counts[i] * ucc_dt_size(dt),
mtype, mtype), task, out);
}
Expand Down
4 changes: 4 additions & 0 deletions src/components/tl/ucp/tl_ucp_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ ucc_status_t ucc_tl_ucp_alg_id_to_init(int alg_id, const char *alg_id_str,
ucc_base_coll_init_fn_t *init)
{
ucc_status_t status = UCC_OK;

if (alg_id_str) {
alg_id = alg_id_from_str(coll_type, alg_id_str);
}
Expand Down Expand Up @@ -274,6 +275,9 @@ ucc_status_t ucc_tl_ucp_alg_id_to_init(int alg_id, const char *alg_id_str,
case UCC_TL_UCP_ALLREDUCE_ALG_SRA_KNOMIAL:
*init = ucc_tl_ucp_allreduce_sra_knomial_init;
break;
case UCC_TL_UCP_ALLREDUCE_ALG_DBT:
*init = ucc_tl_ucp_allreduce_dbt_init;
break;
default:
status = UCC_ERR_INVALID_PARAM;
break;
Expand Down
37 changes: 37 additions & 0 deletions test/gtest/coll/test_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,43 @@ TYPED_TEST(test_allreduce_alg, sra_knomial_pipelined) {
}
}

TYPED_TEST(test_allreduce_alg, dbt) {
int n_procs = 15;
ucc_job_env_t env = {{"UCC_CL_BASIC_TUNE", "inf"},
{"UCC_TL_UCP_TUNE", "allreduce:@dbt:inf"}};
UccJob job(n_procs, UccJob::UCC_JOB_CTX_GLOBAL, env);
UccTeam_h team = job.create_team(n_procs);
int repeat = 3;
UccCollCtxVec ctxs;
std::vector<ucc_memory_type_t> mt = {UCC_MEMORY_TYPE_HOST};

if (UCC_OK == ucc_mc_available(UCC_MEMORY_TYPE_CUDA)) {
mt.push_back(UCC_MEMORY_TYPE_CUDA);
}
if (UCC_OK == ucc_mc_available( UCC_MEMORY_TYPE_CUDA_MANAGED)) {
mt.push_back( UCC_MEMORY_TYPE_CUDA_MANAGED);
}

for (auto count : {65536, 123567}) {
for (auto inplace : {TEST_NO_INPLACE, TEST_INPLACE}) {
for (auto m : mt) {
SET_MEM_TYPE(m);
this->set_inplace(inplace);
this->data_init(n_procs, TypeParam::dt, count, ctxs, true);
UccReq req(team, ctxs);

for (auto i = 0; i < repeat; i++) {
req.start();
req.wait();
EXPECT_EQ(true, this->data_validate(ctxs));
this->reset(ctxs);
}
this->data_fini(ctxs);
}
}
}
}

TYPED_TEST(test_allreduce_alg, rab) {
int n_procs = 15;
ucc_job_env_t env = {{"UCC_CL_HIER_TUNE", "allreduce:@rab:0-inf:inf"},
Expand Down

0 comments on commit b3c260c

Please sign in to comment.