Skip to content

Commit

Permalink
REVIEW: fix review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev committed Jan 16, 2024
1 parent f806044 commit 27a4c9f
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 126 deletions.
209 changes: 89 additions & 120 deletions src/components/tl/ucp/reduce/reduce_dbt.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ static void recv_completion_1(void *request, ucs_status_t status,
{
ucc_tl_ucp_task_t *task = (ucc_tl_ucp_task_t *)user_data;

task->reduce_dbt.t1.recv++;
task->reduce_dbt.trees[0].recv++;
recv_completion_common(request, status, info, user_data);
}

Expand All @@ -63,7 +63,7 @@ static void recv_completion_2(void *request, ucs_status_t status,
{
ucc_tl_ucp_task_t *task = (ucc_tl_ucp_task_t *)user_data;

task->reduce_dbt.t2.recv++;
task->reduce_dbt.trees[1].recv++;
recv_completion_common(request, status, info, user_data);
}

Expand Down Expand Up @@ -93,23 +93,22 @@ static inline void single_tree_reduce(ucc_tl_ucp_task_t *task, void *sbuf,

void ucc_tl_ucp_reduce_dbt_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task =
ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task,
ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_dbt_single_tree_t *trees = task->reduce_dbt.trees ;
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
ucc_rank_t coll_root = (ucc_rank_t)args->root;
int is_root = rank == coll_root;
ucc_dbt_single_tree_t t1 = task->reduce_dbt.t1;
ucc_dbt_single_tree_t t2 = task->reduce_dbt.t2;
ucp_tag_recv_nbx_callback_t cb[2] = {recv_completion_1,
recv_completion_2};
void *t1_sbuf, *t1_rbuf, *t2_sbuf, *t2_rbuf;
uint32_t i, j;
void *sbuf[2], *rbuf[2];
uint32_t i, j, k;
ucc_memory_type_t mtype;
ucc_datatype_t dt;
size_t count, count_t1, data_size, data_size_t1,
data_size_t2;
size_t count, data_size, data_size_t1;
size_t counts[2];
int avg_pre_op, avg_post_op;

if (is_root) {
Expand All @@ -122,88 +121,67 @@ void ucc_tl_ucp_reduce_dbt_progress(ucc_coll_task_t *coll_task)
count = args->src.info.count;
}

count_t1 = (count % 2) ? count / 2 + 1 : count / 2;
counts[0] = (count % 2) ? count / 2 + 1 : count / 2;
counts[1] = count / 2;
data_size = count * ucc_dt_size(dt);
data_size_t1 = count_t1 * ucc_dt_size(dt);
data_size_t2 = count / 2 * ucc_dt_size(dt);
data_size_t1 = counts[0] * ucc_dt_size(dt);
avg_pre_op = ((args->op == UCC_OP_AVG) &&
UCC_TL_UCP_TEAM_LIB(team)->cfg.reduce_avg_pre_op);
avg_post_op = ((args->op == UCC_OP_AVG) &&
!UCC_TL_UCP_TEAM_LIB(team)->cfg.reduce_avg_pre_op);
t1_rbuf = task->reduce_dbt.scratch;
t2_rbuf = PTR_OFFSET(t1_rbuf, data_size_t1 * 2);
t1_sbuf = avg_pre_op ? PTR_OFFSET(t1_rbuf, data_size * 2)
: args->src.info.buffer;
t2_sbuf = PTR_OFFSET(t1_sbuf, data_size_t1);

rbuf[0] = task->reduce_dbt.scratch;
rbuf[1] = PTR_OFFSET(rbuf[0], data_size_t1 * 2);;
sbuf[0] = avg_pre_op ? PTR_OFFSET(rbuf[0], data_size * 2)
: args->src.info.buffer;;
sbuf[1] = PTR_OFFSET(sbuf[0], data_size_t1);

UCC_REDUCE_DBT_GOTO_STATE(task->reduce_dbt.state);
j = 0;
for (i = 0; i < 2; i++) {
if (t1.children[i] != UCC_RANK_INVALID) {
UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(PTR_OFFSET(t1_rbuf,
data_size_t1 * j),
data_size_t1, mtype,
t1.children[i], team, task, cb[0],
(void *)task),
task, out);
j++;
}
}
j = 0;
for (k = 0; k < 2; k++) {
if (trees[i].children[k] != UCC_RANK_INVALID) {
UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(
PTR_OFFSET(rbuf[i], counts[i] * ucc_dt_size(dt) * j),
counts[i] * ucc_dt_size(dt), mtype,
trees[i].children[k], team, task, cb[i],
(void *)task),
task, out);
j++;
}

j = 0;
for (i = 0; i < 2; i++) {
if (t2.children[i] != UCC_RANK_INVALID) {
UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(PTR_OFFSET(t2_rbuf,
data_size_t2 * j),
data_size_t2, mtype,
t2.children[i], team, task, cb[1],
(void *)task),
task, out);
j++;
}
}
task->reduce_dbt.state = REDUCE;

REDUCE:
if (t1.recv == t1.n_children && !task->reduce_dbt.t1_reduction_comp) {
if (t1.n_children > 0) {
single_tree_reduce(task, t1_sbuf, t1_rbuf, t1.n_children, count_t1,
data_size_t1, dt, args,
avg_post_op && t1.root == rank);
}
task->reduce_dbt.t1_reduction_comp = 1;
}
if (t2.recv == t2.n_children && !task->reduce_dbt.t2_reduction_comp) {
if (t2.n_children > 0) {
single_tree_reduce(task, t2_sbuf, t2_rbuf, t2.n_children,
count / 2, data_size_t2, dt, args,
avg_post_op && t2.root == rank);
for (i = 0; i < 2; i++) {
if (trees[i].recv == trees[i].n_children &&
!task->reduce_dbt.reduction_comp[i]) {
if (trees[i].n_children > 0) {
single_tree_reduce(task, sbuf[i], rbuf[i], trees[i].n_children,
counts[i], counts[i] * ucc_dt_size(dt), dt,
args, avg_post_op && trees[i].root == rank);
}
task->reduce_dbt.reduction_comp[i] = 1;
}
task->reduce_dbt.t2_reduction_comp = 1;
}

if (rank != t1.root && task->reduce_dbt.t1_reduction_comp &&
!task->reduce_dbt.t1_send_comp) {
UCPCHECK_GOTO(ucc_tl_ucp_send_nb((t1.n_children > 0) ? t1_rbuf
: t1_sbuf,
data_size_t1, mtype, t1.parent, team,
task),
task, out);
task->reduce_dbt.t1_send_comp = 1;
}

if (rank != t2.root && task->reduce_dbt.t2_reduction_comp &&
!task->reduce_dbt.t2_send_comp) {
UCPCHECK_GOTO(ucc_tl_ucp_send_nb((t2.n_children > 0) ? t2_rbuf
: t2_sbuf,
data_size_t2, mtype, t2.parent, team,
task),
task, out);
task->reduce_dbt.t2_send_comp = 1;
for (i = 0; i < 2; i++) {
if (rank != trees[i].root && task->reduce_dbt.reduction_comp[i] &&
!task->reduce_dbt.send_comp[i]) {
UCPCHECK_GOTO(ucc_tl_ucp_send_nb((trees[i].n_children > 0) ? rbuf[i]
: sbuf[i],
counts[i] * ucc_dt_size(dt),
mtype, trees[i].parent, team,
task),
task, out);
task->reduce_dbt.send_comp[i] = 1;
}
}

if (!task->reduce_dbt.t1_reduction_comp ||
!task->reduce_dbt.t2_reduction_comp) {
if (!task->reduce_dbt.reduction_comp[0] ||
!task->reduce_dbt.reduction_comp[1]) {
return;
}
TEST:
Expand All @@ -213,55 +191,45 @@ void ucc_tl_ucp_reduce_dbt_progress(ucc_coll_task_t *coll_task)
}

/* tree roots send to coll root*/
if (rank == t1.root && !is_root) {
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(t1_rbuf, data_size_t1, mtype,
coll_root, team, task),
task, out);
}

if (rank == t2.root && !is_root) {
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(t2_rbuf, data_size_t2, mtype,
coll_root, team, task),
task, out);
for (i = 0; i < 2; i++) {
if (rank == trees[i].root && !is_root) {
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(rbuf[i],
counts[i] * ucc_dt_size(dt),
mtype, coll_root, team, task),
task, out);
}
}

task->reduce_dbt.t1_reduction_comp = t1.recv;
task->reduce_dbt.t2_reduction_comp = t2.recv;
task->reduce_dbt.reduction_comp[0] = trees[0].recv;
task->reduce_dbt.reduction_comp[1] = trees[1].recv;

if (is_root && rank != t1.root) {
UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(args->dst.info.buffer, data_size_t1,
mtype, t1.root, team, task, cb[0],
(void *)task),
task, out);
task->reduce_dbt.t1_reduction_comp++;
for (i = 0; i < 2; i++) {
if (is_root && rank != trees[i].root) {
UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(PTR_OFFSET(args->dst.info.buffer,
i * counts[0] * ucc_dt_size(dt)),
counts[i] * ucc_dt_size(dt),
mtype, trees[i].root, team, task,
cb[i], (void *)task),
task, out);
task->reduce_dbt.reduction_comp[i]++;
}
}

if (is_root && rank != t2.root) {
UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(PTR_OFFSET(args->dst.info.buffer,
data_size_t1),
data_size_t2, mtype, t2.root, team,
task, cb[1], (void *)task),
task, out);
task->reduce_dbt.t2_reduction_comp++;
}
TEST_ROOT:
if (UCC_INPROGRESS == ucc_tl_ucp_test_send(task) ||
task->reduce_dbt.t1_reduction_comp != t1.recv ||
task->reduce_dbt.t2_reduction_comp != t2.recv) {
task->reduce_dbt.reduction_comp[0] != trees[0].recv ||
task->reduce_dbt.reduction_comp[1] != trees[1].recv) {
task->reduce_dbt.state = TEST_ROOT;
return;
}

if (is_root && rank == t1.root) {
UCPCHECK_GOTO(ucc_mc_memcpy(args->dst.info.buffer, t1_rbuf,
data_size_t1, mtype, mtype), task, out);
}

if (is_root && rank == t2.root) {
UCPCHECK_GOTO(ucc_mc_memcpy(PTR_OFFSET(args->dst.info.buffer,
data_size_t1), t2_rbuf,
data_size_t2, mtype, mtype),
task, out);
for (i = 0; i < 2; i++) {
if (is_root && rank == trees[i].root) {
UCPCHECK_GOTO(ucc_mc_memcpy(PTR_OFFSET(args->dst.info.buffer,
i * counts[i] * ucc_dt_size(dt)),
rbuf[i], counts[i] * ucc_dt_size(dt),
mtype, mtype), task, out);
}
}

task->super.status = UCC_OK;
Expand All @@ -284,12 +252,13 @@ ucc_status_t ucc_tl_ucp_reduce_dbt_start(ucc_coll_task_t *coll_task)
size_t count, data_size;
ucc_status_t status;

task->reduce_dbt.t1.recv = 0;
task->reduce_dbt.t2.recv = 0;
task->reduce_dbt.t1_reduction_comp = 0;
task->reduce_dbt.t2_reduction_comp = 0;
task->reduce_dbt.t1_send_comp = 0;
task->reduce_dbt.t2_send_comp = 0;
task->reduce_dbt.trees[0].recv = 0;
task->reduce_dbt.trees[1].recv = 0;
task->reduce_dbt.reduction_comp[0] = 0;
task->reduce_dbt.reduction_comp[1] = 0;
task->reduce_dbt.send_comp[0] = 0;
task->reduce_dbt.send_comp[1] = 0;

ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);

if (args->root == rank) {
Expand Down Expand Up @@ -364,8 +333,8 @@ ucc_status_t ucc_tl_ucp_reduce_dbt_init(ucc_base_coll_args_t *coll_args,
tl_team = TASK_TEAM(task);
rank = UCC_TL_TEAM_RANK(tl_team);
size = UCC_TL_TEAM_SIZE(tl_team);
ucc_dbt_build_trees(rank, size, &task->reduce_dbt.t1,
&task->reduce_dbt.t2);
ucc_dbt_build_trees(rank, size, &task->reduce_dbt.trees[0],
&task->reduce_dbt.trees[1]);

if (coll_args->args.root == rank) {
count = coll_args->args.dst.info.count;
Expand All @@ -379,7 +348,7 @@ ucc_status_t ucc_tl_ucp_reduce_dbt_init(ucc_base_coll_args_t *coll_args,
data_size = count * ucc_dt_size(dt);
task->reduce_dbt.scratch_mc_header = NULL;
status = ucc_mc_alloc(&task->reduce_dbt.scratch_mc_header, 3 * data_size,
mtype);
mtype);
if (ucc_unlikely(status != UCC_OK)) {
return status;
}
Expand Down
9 changes: 3 additions & 6 deletions src/components/tl/ucp/tl_ucp_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,10 @@ typedef struct ucc_tl_ucp_task {
ucc_ee_executor_t *executor;
} reduce_kn;
struct {
ucc_dbt_single_tree_t t1;
ucc_dbt_single_tree_t t2;
int state;
int t1_reduction_comp;
int t2_reduction_comp;
int t1_send_comp;
int t2_send_comp;
ucc_dbt_single_tree_t trees[2];
int reduction_comp[2];
int send_comp[2];
void *scratch;
ucc_mc_buffer_header_t *scratch_mc_header;
ucc_ee_executor_task_t *etask;
Expand Down

0 comments on commit 27a4c9f

Please sign in to comment.