Skip to content

Commit

Permalink
TL/UCP: bcast active_set size greater than 2
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarka authored and Nicholas Sarkauskas committed May 22, 2024
1 parent 7e5dbfd commit b122da3
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 103 deletions.
1 change: 1 addition & 0 deletions src/components/tl/mlx5/tl_mlx5_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ ucc_status_t ucc_tl_mlx5_context_ib_ctx_pd_setup(ucc_base_context_t *context)
steam = core_ctx->service_team;
s.map = sbgp->map;
s.myrank = sbgp->group_rank;

status = UCC_TL_TEAM_IFACE(steam)->scoll.bcast(
&steam->super, sbcast_data, sbcast_data_length, PD_OWNER_RANK, s, &req);

Expand Down
59 changes: 41 additions & 18 deletions src/components/tl/ucp/bcast/bcast_dbt.c
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,12 @@ static void recv_completion_2(void *request, ucs_status_t status,

void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task)
{
uint32_t i;
ucc_tl_ucp_task_t *task =
ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
ucc_rank_t rank = task->subset.myrank;
ucc_dbt_single_tree_t t1 = task->bcast_dbt.t1;
ucc_dbt_single_tree_t t2 = task->bcast_dbt.t2;
void *buffer = args->src.info.buffer;
Expand All @@ -87,20 +88,28 @@ void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task)
ucc_rank_t coll_root = (ucc_rank_t)args->root;
ucp_tag_recv_nbx_callback_t cb[2] = {recv_completion_1,
recv_completion_2};
uint32_t i;

if (UCC_COLL_ARGS_ACTIVE_SET(&(TASK_ARGS(task)))) {
coll_root = ucc_ep_map_local_rank(task->subset.map, coll_root);
}

UCC_BCAST_DBT_GOTO_STATE(task->bcast_dbt.state);

if (rank != t1.root && rank != coll_root) {
UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(buffer, data_size_t1, mtype,
t1.parent, team, task, cb[0],
ucc_ep_map_eval(task->subset.map,
t1.parent),
team, task, cb[0],
(void *)task),
task, out);
}

if (rank != t2.root && rank != coll_root) {
UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(PTR_OFFSET(buffer, data_size_t1),
data_size_t2, mtype, t2.parent, team,
data_size_t2, mtype,
ucc_ep_map_eval(task->subset.map,
t2.parent),
team,
task, cb[1], (void *)task),
task, out);
}
Expand All @@ -114,7 +123,10 @@ void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task)
if ((t1.children[i] != UCC_RANK_INVALID) &&
(t1.children[i] != coll_root)) {
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(buffer, data_size_t1, mtype,
t1.children[i], team, task),
ucc_ep_map_eval(
task->subset.map,
t1.children[i]),
team, task),
task, out);
}
}
Expand All @@ -133,7 +145,10 @@ void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task)
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(PTR_OFFSET(buffer,
data_size_t1),
data_size_t2, mtype,
t2.children[i], team, task),
ucc_ep_map_eval(
task->subset.map,
t2.children[i]),
team, task),
task, out);
}
}
Expand Down Expand Up @@ -161,7 +176,7 @@ ucc_status_t ucc_tl_ucp_bcast_dbt_start(ucc_coll_task_t *coll_task)
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_status_t status = UCC_OK;
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
ucc_rank_t rank = task->subset.myrank;
void *buffer = args->src.info.buffer;
ucc_memory_type_t mtype = args->src.info.mem_type;
ucc_datatype_t dt = args->src.info.datatype;
Expand All @@ -170,7 +185,9 @@ ucc_status_t ucc_tl_ucp_bcast_dbt_start(ucc_coll_task_t *coll_task)
: count / 2;
size_t data_size_t1 = count_t1 * ucc_dt_size(dt);
size_t data_size_t2 = count / 2 * ucc_dt_size(dt);
ucc_rank_t coll_root = (ucc_rank_t)args->root;
ucc_rank_t coll_root = ucc_ep_map_local_rank(
task->subset.map,
(ucc_rank_t)args->root);
ucc_rank_t t1_root = task->bcast_dbt.t1.root;
ucc_rank_t t2_root = task->bcast_dbt.t2.root;
ucp_tag_recv_nbx_callback_t cb[2] = {recv_completion_1,
Expand All @@ -181,23 +198,28 @@ ucc_status_t ucc_tl_ucp_bcast_dbt_start(ucc_coll_task_t *coll_task)
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);

if (rank == coll_root && coll_root != t1_root) {
status = ucc_tl_ucp_send_nb(buffer, data_size_t1, mtype, t1_root, team,
task);
status = ucc_tl_ucp_send_nb(buffer, data_size_t1, mtype,
ucc_ep_map_eval(task->subset.map, t1_root),
team, task);
if (UCC_OK != status) {
return status;
}
}

if (rank == coll_root && coll_root != t2_root) {
status = ucc_tl_ucp_send_nb(PTR_OFFSET(buffer, data_size_t1),
data_size_t2, mtype, t2_root, team, task);
data_size_t2, mtype,
ucc_ep_map_eval(task->subset.map, t2_root),
team, task);
if (UCC_OK != status) {
return status;
}
}

if (rank != coll_root && rank == t1_root) {
status = ucc_tl_ucp_recv_cb(buffer, data_size_t1, mtype, coll_root,
status = ucc_tl_ucp_recv_cb(buffer, data_size_t1, mtype,
ucc_ep_map_eval(task->subset.map,
coll_root),
team, task, cb[0], (void *)task);
if (UCC_OK != status) {
return status;
Expand All @@ -206,8 +228,10 @@ ucc_status_t ucc_tl_ucp_bcast_dbt_start(ucc_coll_task_t *coll_task)

if (rank != coll_root && rank == t2_root) {
status = ucc_tl_ucp_recv_cb(PTR_OFFSET(buffer, data_size_t1),
data_size_t2, mtype, coll_root, team, task,
cb[1], (void *)task);
data_size_t2, mtype,
ucc_ep_map_eval(task->subset.map,
coll_root),
team, task, cb[1], (void *)task);
if (UCC_OK != status) {
return status;
}
Expand All @@ -227,7 +251,6 @@ ucc_status_t ucc_tl_ucp_bcast_dbt_init(
ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
ucc_coll_task_t **task_h)
{
ucc_tl_ucp_team_t *tl_team;
ucc_tl_ucp_task_t *task;
ucc_rank_t rank, size;

Expand All @@ -236,9 +259,9 @@ ucc_status_t ucc_tl_ucp_bcast_dbt_init(
task->super.progress = ucc_tl_ucp_bcast_dbt_progress;
task->super.finalize = ucc_tl_ucp_bcast_dbt_finalize;
task->n_polls = ucc_max(1, task->n_polls);
tl_team = TASK_TEAM(task);
rank = UCC_TL_TEAM_RANK(tl_team);
size = UCC_TL_TEAM_SIZE(tl_team);
rank = task->subset.myrank;
size = (ucc_rank_t)task->subset.map.ep_num;

ucc_dbt_build_trees(rank, size, &task->bcast_dbt.t1,
&task->bcast_dbt.t2);

Expand Down
14 changes: 11 additions & 3 deletions src/components/tl/ucp/bcast/bcast_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,28 @@

void ucc_tl_ucp_bcast_knomial_progress(ucc_coll_task_t *coll_task)
{
uint32_t i;

ucc_rank_t vrank;
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t rank = task->subset.myrank;
ucc_rank_t size = (ucc_rank_t)task->subset.map.ep_num;
ucc_rank_t root = (uint32_t)TASK_ARGS(task).root;

uint32_t radix = task->bcast_kn.radix;
ucc_rank_t vrank = (rank - root + size) % size;
ucc_rank_t root = (uint32_t)TASK_ARGS(task).root;
ucc_rank_t dist = task->bcast_kn.dist;
void *buffer = TASK_ARGS(task).src.info.buffer;
ucc_memory_type_t mtype = TASK_ARGS(task).src.info.mem_type;
size_t data_size = TASK_ARGS(task).src.info.count *
ucc_dt_size(TASK_ARGS(task).src.info.datatype);
ucc_rank_t vpeer, peer, vroot_at_level, root_at_level, pos;
uint32_t i;

if (UCC_COLL_ARGS_ACTIVE_SET(&(TASK_ARGS(task)))) {
root = ucc_ep_map_local_rank(task->subset.map, root);
}

vrank = (rank - root + size) % size;

if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
return;
Expand Down
8 changes: 2 additions & 6 deletions src/components/tl/ucp/tl_ucp_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -364,20 +364,16 @@ ucc_tl_ucp_init_task(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team)
ucc_coll_task_init(&task->super, coll_args, team);

if (UCC_COLL_ARGS_ACTIVE_SET(&coll_args->args)) {
task->tagged.tag = (coll_args->mask & UCC_COLL_ARGS_FIELD_TAG)
task->tagged.tag = (coll_args->args.mask & UCC_COLL_ARGS_FIELD_TAG)
? coll_args->args.tag : UCC_TL_UCP_ACTIVE_SET_TAG;
task->flags |= UCC_TL_UCP_TASK_FLAG_SUBSET;
task->subset.map = ucc_active_set_to_ep_map(&coll_args->args);
task->subset.myrank =
ucc_ep_map_local_rank(task->subset.map,
UCC_TL_TEAM_RANK(tl_team));
ucc_assert(coll_args->args.coll_type == UCC_COLL_TYPE_BCAST);
/* root value in args corresponds to the original team ranks,
need to convert to subset local value */
TASK_ARGS(task).root = ucc_ep_map_local_rank(task->subset.map,
coll_args->args.root);
} else {
if (coll_args->mask & UCC_COLL_ARGS_FIELD_TAG) {
if (coll_args->args.mask & UCC_COLL_ARGS_FIELD_TAG) {
task->tagged.tag = coll_args->args.tag;
} else {
tl_team->seq_num = (tl_team->seq_num + 1) % UCC_TL_UCP_MAX_COLL_TAG;
Expand Down
5 changes: 2 additions & 3 deletions src/core/ucc_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,8 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init,
}

if (UCC_COLL_ARGS_ACTIVE_SET(coll_args) &&
((UCC_COLL_TYPE_BCAST != coll_args->coll_type) ||
coll_args->active_set.size != 2)) {
ucc_warn("Active Sets are only supported for bcast and set size = 2");
(UCC_COLL_TYPE_BCAST != coll_args->coll_type)) {
ucc_warn("Active Sets are only supported for bcast");
return UCC_ERR_NOT_SUPPORTED;
}

Expand Down
7 changes: 4 additions & 3 deletions src/utils/ucc_coll_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ size_t ucc_coll_args_msgsize(const ucc_coll_args_t *args,
ucc_memory_type_t ucc_coll_args_mem_type(const ucc_coll_args_t *args,
ucc_rank_t rank);


/* Convert rank from subset space to rank space (UCC team space) */
static inline ucc_rank_t ucc_ep_map_eval(ucc_ep_map_t map, ucc_rank_t rank)
{
ucc_rank_t r;
Expand Down Expand Up @@ -261,8 +261,9 @@ static inline size_t ucc_buffer_block_offset(size_t total_count,

/* Given the rank space A (e.g. core ucc team), a subset B (e.g. active set
within the core team), the ep_map that maps ranks from the subset B to A,
and the rank of a process within A.
The function below computes the local rank of the process within subset B. */
and the rank of a process within A, the function below computes the local
rank of the process within subset B,
i.e., it converts from rank space (UCC team) to subset space. */
static inline ucc_rank_t ucc_ep_map_local_rank(ucc_ep_map_t map,
ucc_rank_t rank)
{
Expand Down
Loading

0 comments on commit b122da3

Please sign in to comment.