diff --git a/src/components/tl/ucp/reduce/reduce_dbt.c b/src/components/tl/ucp/reduce/reduce_dbt.c
index 8d85cb0dae..1bc9933169 100644
--- a/src/components/tl/ucp/reduce/reduce_dbt.c
+++ b/src/components/tl/ucp/reduce/reduce_dbt.c
@@ -53,7 +53,7 @@ static void recv_completion_1(void *request, ucs_status_t status,
 {
     ucc_tl_ucp_task_t *task = (ucc_tl_ucp_task_t *)user_data;
 
-    task->reduce_dbt.t1.recv++;
+    task->reduce_dbt.trees[0].recv++;
     recv_completion_common(request, status, info, user_data);
 }
 
@@ -63,7 +63,7 @@ static void recv_completion_2(void *request, ucs_status_t status,
 {
     ucc_tl_ucp_task_t *task = (ucc_tl_ucp_task_t *)user_data;
 
-    task->reduce_dbt.t2.recv++;
+    task->reduce_dbt.trees[1].recv++;
     recv_completion_common(request, status, info, user_data);
 }
 
@@ -93,23 +93,22 @@ static inline void single_tree_reduce(ucc_tl_ucp_task_t *task, void *sbuf, void
 
 void ucc_tl_ucp_reduce_dbt_progress(ucc_coll_task_t *coll_task)
 {
-    ucc_tl_ucp_task_t *task =
-        ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
+    ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task,
+                                             ucc_tl_ucp_task_t);
     ucc_tl_ucp_team_t *team = TASK_TEAM(task);
     ucc_coll_args_t *args = &TASK_ARGS(task);
+    ucc_dbt_single_tree_t *trees = task->reduce_dbt.trees;
     ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
     ucc_rank_t coll_root = (ucc_rank_t)args->root;
     int is_root = rank == coll_root;
-    ucc_dbt_single_tree_t t1 = task->reduce_dbt.t1;
-    ucc_dbt_single_tree_t t2 = task->reduce_dbt.t2;
     ucp_tag_recv_nbx_callback_t cb[2] = {recv_completion_1, recv_completion_2};
-    void *t1_sbuf, *t1_rbuf, *t2_sbuf, *t2_rbuf;
-    uint32_t i, j;
+    void *sbuf[2], *rbuf[2];
+    uint32_t i, j, k;
     ucc_memory_type_t mtype;
     ucc_datatype_t dt;
-    size_t count, count_t1, data_size, data_size_t1,
-           data_size_t2;
+    size_t count, data_size, data_size_t1;
+    size_t counts[2];
     int avg_pre_op, avg_post_op;
 
     if (is_root) {
@@ -122,88 +121,67 @@ void ucc_tl_ucp_reduce_dbt_progress(ucc_coll_task_t *coll_task)
         count = args->src.info.count;
     }
 
-    count_t1 = (count % 2) ? count / 2 + 1 : count / 2;
+    counts[0] = (count % 2) ? count / 2 + 1 : count / 2;
+    counts[1] = count / 2;
     data_size = count * ucc_dt_size(dt);
-    data_size_t1 = count_t1 * ucc_dt_size(dt);
-    data_size_t2 = count / 2 * ucc_dt_size(dt);
+    data_size_t1 = counts[0] * ucc_dt_size(dt);
     avg_pre_op = ((args->op == UCC_OP_AVG) &&
                   UCC_TL_UCP_TEAM_LIB(team)->cfg.reduce_avg_pre_op);
     avg_post_op = ((args->op == UCC_OP_AVG) &&
                    !UCC_TL_UCP_TEAM_LIB(team)->cfg.reduce_avg_pre_op);
-    t1_rbuf = task->reduce_dbt.scratch;
-    t2_rbuf = PTR_OFFSET(t1_rbuf, data_size_t1 * 2);
-    t1_sbuf = avg_pre_op ? PTR_OFFSET(t1_rbuf, data_size * 2)
-                         : args->src.info.buffer;
-    t2_sbuf = PTR_OFFSET(t1_sbuf, data_size_t1);
+
+    rbuf[0] = task->reduce_dbt.scratch;
+    rbuf[1] = PTR_OFFSET(rbuf[0], data_size_t1 * 2);
+    sbuf[0] = avg_pre_op ? PTR_OFFSET(rbuf[0], data_size * 2)
+                         : args->src.info.buffer;
+    sbuf[1] = PTR_OFFSET(sbuf[0], data_size_t1);
 
     UCC_REDUCE_DBT_GOTO_STATE(task->reduce_dbt.state);
-    j = 0;
     for (i = 0; i < 2; i++) {
-        if (t1.children[i] != UCC_RANK_INVALID) {
-            UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(PTR_OFFSET(t1_rbuf,
-                                             data_size_t1 * j),
-                                             data_size_t1, mtype,
-                                             t1.children[i], team, task, cb[0],
-                                             (void *)task),
-                          task, out);
-            j++;
-        }
-    }
+        j = 0;
+        for (k = 0; k < 2; k++) {
+            if (trees[i].children[k] != UCC_RANK_INVALID) {
+                UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(
+                    PTR_OFFSET(rbuf[i], counts[i] * ucc_dt_size(dt) * j),
+                    counts[i] * ucc_dt_size(dt), mtype,
+                    trees[i].children[k], team, task, cb[i],
+                    (void *)task),
+                    task, out);
+                j++;
+            }
 
-    j = 0;
-    for (i = 0; i < 2; i++) {
-        if (t2.children[i] != UCC_RANK_INVALID) {
-            UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(PTR_OFFSET(t2_rbuf,
-                                             data_size_t2 * j),
-                                             data_size_t2, mtype,
-                                             t2.children[i], team, task, cb[1],
-                                             (void *)task),
-                          task, out);
-            j++;
-        }
+        }
     }
 
     task->reduce_dbt.state = REDUCE;
 REDUCE:
-    if (t1.recv == t1.n_children && !task->reduce_dbt.t1_reduction_comp) {
-        if (t1.n_children > 0) {
-            single_tree_reduce(task, t1_sbuf, t1_rbuf, t1.n_children, count_t1,
-                               data_size_t1, dt, args,
-                               avg_post_op && t1.root == rank);
-        }
-        task->reduce_dbt.t1_reduction_comp = 1;
-    }
 
-    if (t2.recv == t2.n_children && !task->reduce_dbt.t2_reduction_comp) {
-        if (t2.n_children > 0) {
-            single_tree_reduce(task, t2_sbuf, t2_rbuf, t2.n_children,
-                               count / 2, data_size_t2, dt, args,
-                               avg_post_op && t2.root == rank);
+    for (i = 0; i < 2; i++) {
+        if (trees[i].recv == trees[i].n_children &&
+            !task->reduce_dbt.reduction_comp[i]) {
+            if (trees[i].n_children > 0) {
+                single_tree_reduce(task, sbuf[i], rbuf[i], trees[i].n_children,
+                                   counts[i], counts[i] * ucc_dt_size(dt), dt,
+                                   args, avg_post_op && trees[i].root == rank);
+            }
+            task->reduce_dbt.reduction_comp[i] = 1;
         }
-        task->reduce_dbt.t2_reduction_comp = 1;
-    }
-
-    if (rank != t1.root && task->reduce_dbt.t1_reduction_comp &&
-        !task->reduce_dbt.t1_send_comp) {
-        UCPCHECK_GOTO(ucc_tl_ucp_send_nb((t1.n_children > 0) ? t1_rbuf
-                                                             : t1_sbuf,
-                                         data_size_t1, mtype, t1.parent, team,
-                                         task),
-                      task, out);
-        task->reduce_dbt.t1_send_comp = 1;
     }
 
-    if (rank != t2.root && task->reduce_dbt.t2_reduction_comp &&
-        !task->reduce_dbt.t2_send_comp) {
-        UCPCHECK_GOTO(ucc_tl_ucp_send_nb((t2.n_children > 0) ? t2_rbuf
-                                                             : t2_sbuf,
-                                         data_size_t2, mtype, t2.parent, team,
-                                         task),
-                      task, out);
-        task->reduce_dbt.t2_send_comp = 1;
+    for (i = 0; i < 2; i++) {
+        if (rank != trees[i].root && task->reduce_dbt.reduction_comp[i] &&
+            !task->reduce_dbt.send_comp[i]) {
+            UCPCHECK_GOTO(ucc_tl_ucp_send_nb((trees[i].n_children > 0) ? rbuf[i]
+                                             : sbuf[i],
+                                             counts[i] * ucc_dt_size(dt),
+                                             mtype, trees[i].parent, team,
+                                             task),
+                          task, out);
+            task->reduce_dbt.send_comp[i] = 1;
+        }
     }
 
-    if (!task->reduce_dbt.t1_reduction_comp ||
-        !task->reduce_dbt.t2_reduction_comp) {
+    if (!task->reduce_dbt.reduction_comp[0] ||
+        !task->reduce_dbt.reduction_comp[1]) {
         return;
     }
 TEST:
@@ -213,55 +191,45 @@ void ucc_tl_ucp_reduce_dbt_progress(ucc_coll_task_t *coll_task)
     }
 
     /* tree roots send to coll root*/
-    if (rank == t1.root && !is_root) {
-        UCPCHECK_GOTO(ucc_tl_ucp_send_nb(t1_rbuf, data_size_t1, mtype,
-                                         coll_root, team, task),
-                      task, out);
-    }
-
-    if (rank == t2.root && !is_root) {
-        UCPCHECK_GOTO(ucc_tl_ucp_send_nb(t2_rbuf, data_size_t2, mtype,
-                                         coll_root, team, task),
-                      task, out);
+    for (i = 0; i < 2; i++) {
+        if (rank == trees[i].root && !is_root) {
+            UCPCHECK_GOTO(ucc_tl_ucp_send_nb(rbuf[i],
+                                             counts[i] * ucc_dt_size(dt),
+                                             mtype, coll_root, team, task),
+                          task, out);
+        }
     }
 
-    task->reduce_dbt.t1_reduction_comp = t1.recv;
-    task->reduce_dbt.t2_reduction_comp = t2.recv;
+    task->reduce_dbt.reduction_comp[0] = trees[0].recv;
+    task->reduce_dbt.reduction_comp[1] = trees[1].recv;
 
-    if (is_root && rank != t1.root) {
-        UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(args->dst.info.buffer, data_size_t1,
-                                         mtype, t1.root, team, task, cb[0],
-                                         (void *)task),
-                      task, out);
-        task->reduce_dbt.t1_reduction_comp++;
+    for (i = 0; i < 2; i++) {
+        if (is_root && rank != trees[i].root) {
+            UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(PTR_OFFSET(args->dst.info.buffer,
+                                             i * counts[0] * ucc_dt_size(dt)),
+                                             counts[i] * ucc_dt_size(dt),
+                                             mtype, trees[i].root, team, task,
+                                             cb[i], (void *)task),
+                          task, out);
+            task->reduce_dbt.reduction_comp[i]++;
+        }
     }
 
-    if (is_root && rank != t2.root) {
-        UCPCHECK_GOTO(ucc_tl_ucp_recv_cb(PTR_OFFSET(args->dst.info.buffer,
-                                                    data_size_t1),
-                                         data_size_t2, mtype, t2.root, team,
-                                         task, cb[1], (void *)task),
-                      task, out);
-        task->reduce_dbt.t2_reduction_comp++;
-    }
 TEST_ROOT:
     if (UCC_INPROGRESS == ucc_tl_ucp_test_send(task) ||
-        task->reduce_dbt.t1_reduction_comp != t1.recv ||
-        task->reduce_dbt.t2_reduction_comp != t2.recv) {
+        task->reduce_dbt.reduction_comp[0] != trees[0].recv ||
+        task->reduce_dbt.reduction_comp[1] != trees[1].recv) {
         task->reduce_dbt.state = TEST_ROOT;
         return;
     }
 
-    if (is_root && rank == t1.root) {
-        UCPCHECK_GOTO(ucc_mc_memcpy(args->dst.info.buffer, t1_rbuf,
-                                    data_size_t1, mtype, mtype), task, out);
-    }
-
-    if (is_root && rank == t2.root) {
-        UCPCHECK_GOTO(ucc_mc_memcpy(PTR_OFFSET(args->dst.info.buffer,
-                                               data_size_t1), t2_rbuf,
-                                    data_size_t2, mtype, mtype),
-                      task, out);
+    for (i = 0; i < 2; i++) {
+        if (is_root && rank == trees[i].root) {
+            UCPCHECK_GOTO(ucc_mc_memcpy(PTR_OFFSET(args->dst.info.buffer,
+                                        i * counts[0] * ucc_dt_size(dt)),
+                                        rbuf[i], counts[i] * ucc_dt_size(dt),
+                                        mtype, mtype), task, out);
+        }
     }
 
     task->super.status = UCC_OK;
@@ -284,12 +252,13 @@ ucc_status_t ucc_tl_ucp_reduce_dbt_start(ucc_coll_task_t *coll_task)
     size_t             count, data_size;
     ucc_status_t       status;
 
-    task->reduce_dbt.t1.recv = 0;
-    task->reduce_dbt.t2.recv = 0;
-    task->reduce_dbt.t1_reduction_comp = 0;
-    task->reduce_dbt.t2_reduction_comp = 0;
-    task->reduce_dbt.t1_send_comp = 0;
-    task->reduce_dbt.t2_send_comp = 0;
+    task->reduce_dbt.trees[0].recv = 0;
+    task->reduce_dbt.trees[1].recv = 0;
+    task->reduce_dbt.reduction_comp[0] = 0;
+    task->reduce_dbt.reduction_comp[1] = 0;
+    task->reduce_dbt.send_comp[0] = 0;
+    task->reduce_dbt.send_comp[1] = 0;
+
     ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
 
     if (args->root == rank) {
@@ -364,8 +333,8 @@ ucc_status_t ucc_tl_ucp_reduce_dbt_init(ucc_base_coll_args_t *coll_args,
     tl_team = TASK_TEAM(task);
     rank    = UCC_TL_TEAM_RANK(tl_team);
     size    = UCC_TL_TEAM_SIZE(tl_team);
-    ucc_dbt_build_trees(rank, size, &task->reduce_dbt.t1,
-                        &task->reduce_dbt.t2);
+    ucc_dbt_build_trees(rank, size, &task->reduce_dbt.trees[0],
+                        &task->reduce_dbt.trees[1]);
 
     if (coll_args->args.root == rank) {
         count = coll_args->args.dst.info.count;
@@ -379,7 +348,7 @@ ucc_status_t ucc_tl_ucp_reduce_dbt_init(ucc_base_coll_args_t *coll_args,
     data_size = count * ucc_dt_size(dt);
     task->reduce_dbt.scratch_mc_header = NULL;
     status = ucc_mc_alloc(&task->reduce_dbt.scratch_mc_header, 3 * data_size,
-                           mtype);
+                          mtype);
     if (ucc_unlikely(status != UCC_OK)) {
         return status;
     }
diff --git a/src/components/tl/ucp/tl_ucp_coll.h b/src/components/tl/ucp/tl_ucp_coll.h
index c4d11a265f..6ab2c661dd 100644
--- a/src/components/tl/ucp/tl_ucp_coll.h
+++ b/src/components/tl/ucp/tl_ucp_coll.h
@@ -202,13 +202,10 @@ typedef struct ucc_tl_ucp_task {
             ucc_ee_executor_t *executor;
         } reduce_kn;
         struct {
-            ucc_dbt_single_tree_t t1;
-            ucc_dbt_single_tree_t t2;
             int state;
-            int t1_reduction_comp;
-            int t2_reduction_comp;
-            int t1_send_comp;
-            int t2_send_comp;
+            ucc_dbt_single_tree_t trees[2];
+            int reduction_comp[2];
+            int send_comp[2];
             void *scratch;
             ucc_mc_buffer_header_t *scratch_mc_header;
             ucc_ee_executor_task_t *etask;