Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/UCP: Add support for active_set broadcast with knomial and dbt for size greater than 2 #926

Merged
merged 1 commit into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/components/tl/mlx5/tl_mlx5_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ ucc_status_t ucc_tl_mlx5_context_ib_ctx_pd_setup(ucc_base_context_t *context)
steam = core_ctx->service_team;
s.map = sbgp->map;
s.myrank = sbgp->group_rank;

status = UCC_TL_TEAM_IFACE(steam)->scoll.bcast(
&steam->super, sbcast_data, sbcast_data_length, PD_OWNER_RANK, s, &req);

Expand Down
59 changes: 41 additions & 18 deletions src/components/tl/ucp/bcast/bcast_dbt.c
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,12 @@ static void recv_completion_2(void *request, ucs_status_t status,

void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task)
{
uint32_t i;
ucc_tl_ucp_task_t *task =
ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
ucc_rank_t rank = task->subset.myrank;
ucc_dbt_single_tree_t t1 = task->bcast_dbt.t1;
ucc_dbt_single_tree_t t2 = task->bcast_dbt.t2;
void *buffer = args->src.info.buffer;
Expand All @@ -87,20 +88,28 @@ void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task)
ucc_rank_t coll_root = (ucc_rank_t)args->root;
ucp_tag_recv_nbx_callback_t cb[2] = {recv_completion_1,
recv_completion_2};
uint32_t i;

if (UCC_COLL_ARGS_ACTIVE_SET(&(TASK_ARGS(task)))) {
coll_root = ucc_ep_map_local_rank(task->subset.map, coll_root);
}

UCC_BCAST_DBT_GOTO_STATE(task->bcast_dbt.state);

if (rank != t1.root && rank != coll_root) {
UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(buffer, data_size_t1, mtype,
t1.parent, team, task, cb[0],
ucc_ep_map_eval(task->subset.map,
t1.parent),
team, task, cb[0],
(void *)task),
task, out);
}

if (rank != t2.root && rank != coll_root) {
UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(PTR_OFFSET(buffer, data_size_t1),
data_size_t2, mtype, t2.parent, team,
data_size_t2, mtype,
ucc_ep_map_eval(task->subset.map,
t2.parent),
team,
task, cb[1], (void *)task),
task, out);
}
Expand All @@ -114,7 +123,10 @@ void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task)
if ((t1.children[i] != UCC_RANK_INVALID) &&
(t1.children[i] != coll_root)) {
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(buffer, data_size_t1, mtype,
t1.children[i], team, task),
ucc_ep_map_eval(
task->subset.map,
t1.children[i]),
team, task),
task, out);
}
}
Expand All @@ -133,7 +145,10 @@ void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task)
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(PTR_OFFSET(buffer,
data_size_t1),
data_size_t2, mtype,
t2.children[i], team, task),
ucc_ep_map_eval(
task->subset.map,
t2.children[i]),
team, task),
task, out);
}
}
Expand Down Expand Up @@ -161,7 +176,7 @@ ucc_status_t ucc_tl_ucp_bcast_dbt_start(ucc_coll_task_t *coll_task)
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_status_t status = UCC_OK;
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
ucc_rank_t rank = task->subset.myrank;
void *buffer = args->src.info.buffer;
ucc_memory_type_t mtype = args->src.info.mem_type;
ucc_datatype_t dt = args->src.info.datatype;
Expand All @@ -170,7 +185,9 @@ ucc_status_t ucc_tl_ucp_bcast_dbt_start(ucc_coll_task_t *coll_task)
: count / 2;
size_t data_size_t1 = count_t1 * ucc_dt_size(dt);
size_t data_size_t2 = count / 2 * ucc_dt_size(dt);
ucc_rank_t coll_root = (ucc_rank_t)args->root;
ucc_rank_t coll_root = ucc_ep_map_local_rank(
task->subset.map,
(ucc_rank_t)args->root);
ucc_rank_t t1_root = task->bcast_dbt.t1.root;
ucc_rank_t t2_root = task->bcast_dbt.t2.root;
ucp_tag_recv_nbx_callback_t cb[2] = {recv_completion_1,
Expand All @@ -181,23 +198,28 @@ ucc_status_t ucc_tl_ucp_bcast_dbt_start(ucc_coll_task_t *coll_task)
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);

if (rank == coll_root && coll_root != t1_root) {
status = ucc_tl_ucp_send_nb(buffer, data_size_t1, mtype, t1_root, team,
task);
status = ucc_tl_ucp_send_nb(buffer, data_size_t1, mtype,
ucc_ep_map_eval(task->subset.map, t1_root),
team, task);
if (UCC_OK != status) {
return status;
}
}

if (rank == coll_root && coll_root != t2_root) {
status = ucc_tl_ucp_send_nb(PTR_OFFSET(buffer, data_size_t1),
data_size_t2, mtype, t2_root, team, task);
data_size_t2, mtype,
ucc_ep_map_eval(task->subset.map, t2_root),
team, task);
if (UCC_OK != status) {
return status;
}
}

if (rank != coll_root && rank == t1_root) {
status = ucc_tl_ucp_recv_cb(buffer, data_size_t1, mtype, coll_root,
status = ucc_tl_ucp_recv_cb(buffer, data_size_t1, mtype,
ucc_ep_map_eval(task->subset.map,
coll_root),
team, task, cb[0], (void *)task);
if (UCC_OK != status) {
return status;
Expand All @@ -206,8 +228,10 @@ ucc_status_t ucc_tl_ucp_bcast_dbt_start(ucc_coll_task_t *coll_task)

if (rank != coll_root && rank == t2_root) {
status = ucc_tl_ucp_recv_cb(PTR_OFFSET(buffer, data_size_t1),
data_size_t2, mtype, coll_root, team, task,
cb[1], (void *)task);
data_size_t2, mtype,
ucc_ep_map_eval(task->subset.map,
coll_root),
team, task, cb[1], (void *)task);
if (UCC_OK != status) {
return status;
}
Expand All @@ -227,7 +251,6 @@ ucc_status_t ucc_tl_ucp_bcast_dbt_init(
ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
ucc_coll_task_t **task_h)
{
ucc_tl_ucp_team_t *tl_team;
ucc_tl_ucp_task_t *task;
ucc_rank_t rank, size;

Expand All @@ -236,9 +259,9 @@ ucc_status_t ucc_tl_ucp_bcast_dbt_init(
task->super.progress = ucc_tl_ucp_bcast_dbt_progress;
task->super.finalize = ucc_tl_ucp_bcast_dbt_finalize;
task->n_polls = ucc_max(1, task->n_polls);
tl_team = TASK_TEAM(task);
rank = UCC_TL_TEAM_RANK(tl_team);
size = UCC_TL_TEAM_SIZE(tl_team);
rank = task->subset.myrank;
size = (ucc_rank_t)task->subset.map.ep_num;

ucc_dbt_build_trees(rank, size, &task->bcast_dbt.t1,
&task->bcast_dbt.t2);

Expand Down
14 changes: 11 additions & 3 deletions src/components/tl/ucp/bcast/bcast_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,28 @@

void ucc_tl_ucp_bcast_knomial_progress(ucc_coll_task_t *coll_task)
{
uint32_t i;

ucc_rank_t vrank;
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t rank = task->subset.myrank;
ucc_rank_t size = (ucc_rank_t)task->subset.map.ep_num;
ucc_rank_t root = (uint32_t)TASK_ARGS(task).root;

uint32_t radix = task->bcast_kn.radix;
ucc_rank_t vrank = (rank - root + size) % size;
ucc_rank_t root = (uint32_t)TASK_ARGS(task).root;
ucc_rank_t dist = task->bcast_kn.dist;
void *buffer = TASK_ARGS(task).src.info.buffer;
ucc_memory_type_t mtype = TASK_ARGS(task).src.info.mem_type;
size_t data_size = TASK_ARGS(task).src.info.count *
ucc_dt_size(TASK_ARGS(task).src.info.datatype);
ucc_rank_t vpeer, peer, vroot_at_level, root_at_level, pos;
uint32_t i;

if (UCC_COLL_ARGS_ACTIVE_SET(&(TASK_ARGS(task)))) {
root = ucc_ep_map_local_rank(task->subset.map, root);
}

vrank = (rank - root + size) % size;

if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
return;
Expand Down
8 changes: 2 additions & 6 deletions src/components/tl/ucp/tl_ucp_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -364,20 +364,16 @@ ucc_tl_ucp_init_task(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team)
ucc_coll_task_init(&task->super, coll_args, team);

if (UCC_COLL_ARGS_ACTIVE_SET(&coll_args->args)) {
task->tagged.tag = (coll_args->mask & UCC_COLL_ARGS_FIELD_TAG)
task->tagged.tag = (coll_args->args.mask & UCC_COLL_ARGS_FIELD_TAG)
nsarka marked this conversation as resolved.
Show resolved Hide resolved
? coll_args->args.tag : UCC_TL_UCP_ACTIVE_SET_TAG;
task->flags |= UCC_TL_UCP_TASK_FLAG_SUBSET;
task->subset.map = ucc_active_set_to_ep_map(&coll_args->args);
task->subset.myrank =
ucc_ep_map_local_rank(task->subset.map,
UCC_TL_TEAM_RANK(tl_team));
ucc_assert(coll_args->args.coll_type == UCC_COLL_TYPE_BCAST);
/* root value in args corresponds to the original team ranks,
need to convert to subset local value */
TASK_ARGS(task).root = ucc_ep_map_local_rank(task->subset.map,
coll_args->args.root);
} else {
if (coll_args->mask & UCC_COLL_ARGS_FIELD_TAG) {
if (coll_args->args.mask & UCC_COLL_ARGS_FIELD_TAG) {
task->tagged.tag = coll_args->args.tag;
} else {
tl_team->seq_num = (tl_team->seq_num + 1) % UCC_TL_UCP_MAX_COLL_TAG;
Expand Down
5 changes: 2 additions & 3 deletions src/core/ucc_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,8 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init,
}

if (UCC_COLL_ARGS_ACTIVE_SET(coll_args) &&
((UCC_COLL_TYPE_BCAST != coll_args->coll_type) ||
coll_args->active_set.size != 2)) {
ucc_warn("Active Sets are only supported for bcast and set size = 2");
(UCC_COLL_TYPE_BCAST != coll_args->coll_type)) {
ucc_warn("Active Sets are only supported for bcast");
return UCC_ERR_NOT_SUPPORTED;
}

Expand Down
7 changes: 4 additions & 3 deletions src/utils/ucc_coll_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ size_t ucc_coll_args_msgsize(const ucc_coll_args_t *args,
ucc_memory_type_t ucc_coll_args_mem_type(const ucc_coll_args_t *args,
ucc_rank_t rank);


/* Convert rank from subset space to rank space (UCC team space) */
static inline ucc_rank_t ucc_ep_map_eval(ucc_ep_map_t map, ucc_rank_t rank)
{
ucc_rank_t r;
Expand Down Expand Up @@ -261,8 +261,9 @@ static inline size_t ucc_buffer_block_offset(size_t total_count,

/* Given the rank space A (e.g. core ucc team), a subset B (e.g. active set
within the core team), the ep_map that maps ranks from the subset B to A,
and the rank of a process within A.
The function below computes the local rank of the process within subset B. */
and the rank of a process within A. The function below computes the local
rank of the process within subset B.
i.e., convert from rank space (UCC team) to subset space */
static inline ucc_rank_t ucc_ep_map_local_rank(ucc_ep_map_t map,
ucc_rank_t rank)
{
Expand Down
Loading
Loading