Skip to content

Commit

Permalink
REVIEW: fix review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev committed Dec 13, 2024
1 parent 887440a commit 1f2ba86
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
4 changes: 2 additions & 2 deletions src/coll_patterns/recursive_knomial.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ static inline ucc_rank_t ucc_kn_pattern_radix_pow_init(ucc_knomial_pattern_t *p,
static inline void
ucc_knomial_pattern_init_impl(ucc_rank_t size, ucc_rank_t rank,
ucc_kn_radix_t radix, ucc_knomial_pattern_t *p,
int backward, int extra)
int backward, int has_extra)
{
ucc_rank_t fs = radix;
ucc_rank_t n_full_subtrees;
Expand All @@ -102,7 +102,7 @@ ucc_knomial_pattern_init_impl(ucc_rank_t size, ucc_rank_t rank,
p->backward = backward;
p->iteration = 0;
n_full_subtrees = ucc_kn_pattern_n_full(p);
p->n_extra = extra ? size - n_full_subtrees * p->full_pow_size : 0;
p->n_extra = has_extra ? size - n_full_subtrees * p->full_pow_size : 0;
p->n_iters = (p->n_extra && n_full_subtrees == 1) ?
p->pow_radix_sup - 1 : p->pow_radix_sup;
p->radix_pow = ucc_kn_pattern_radix_pow_init(p, backward);
Expand Down
27 changes: 16 additions & 11 deletions src/components/tl/ucp/gather/gather_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,20 @@
task->gather_kn.phase = _phase; \
} while (0)

static inline uint32_t calc_buffer_size(ucc_rank_t trank, uint32_t radix,
static inline uint32_t calc_buffer_size(ucc_rank_t vrank, uint32_t radix,
ucc_rank_t tsize)
{
uint32_t radix_valuation;

if (trank == 0) {
if (vrank == 0) {
return tsize;
}

radix_valuation = calc_valuation(trank, radix);
return (uint32_t)ucc_min(pow(radix, radix_valuation), tsize - trank);
radix_valuation = calc_valuation(vrank, radix);
return (uint32_t)ucc_min(pow(radix, radix_valuation), tsize - vrank);
}

/* gather knomial is used as regular gather collective and as part of reduce SRG */
void ucc_tl_ucp_gather_knomial_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
Expand Down Expand Up @@ -86,8 +87,8 @@ void ucc_tl_ucp_gather_knomial_progress(ucc_coll_task_t *coll_task)
task->gather_kn.dist);
}
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(scratch_offset,
msg_size, mtype, peer,
team, task),
msg_size, mtype, peer,
team, task),
task, out);
} else {
/*
Expand Down Expand Up @@ -176,10 +177,11 @@ void ucc_tl_ucp_gather_knomial_progress(ucc_coll_task_t *coll_task)
root_at_level, team, task),
task, out);
} else {
// need to split in this case due to root and tree topology
msg_size = data_size * (tsize - rank);
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(task->gather_kn.scratch,
msg_size, mtype,
root_at_level, team, task),
msg_size, mtype,
root_at_level, team, task),
task, out);
msg_size = data_size * (num_blocks - (tsize - rank));
UCPCHECK_GOTO(
Expand Down Expand Up @@ -226,6 +228,8 @@ ucc_status_t ucc_tl_ucp_gather_knomial_start(ucc_coll_task_t *coll_task)
task->gather_kn.radix, args->src.info.count * size,
&task->gather_kn.p);
} else {
/* reduce srg */
ucc_assert(args->coll_type == UCC_COLL_TYPE_REDUCE);
task->gather_kn.scratch = args->dst.info.buffer;
ucc_kn_gx_pattern_init(size, VRANK(trank, root, size),
task->gather_kn.radix, args->dst.info.count,
Expand Down Expand Up @@ -265,7 +269,7 @@ ucc_status_t ucc_tl_ucp_gather_knomial_init_common(ucc_tl_ucp_task_t *task,
ucc_datatype_t dt;
size_t count, data_size;
uint32_t buffer_size;
int isleaf;
int is_leaf;

if (UCC_IS_ROOT(*args, trank)) {
count = args->dst.info.count;
Expand All @@ -288,10 +292,11 @@ ucc_status_t ucc_tl_ucp_gather_knomial_init_common(ucc_tl_ucp_task_t *task,
if (args->coll_type == UCC_COLL_TYPE_REDUCE) {
task->gather_kn.scratch = args->dst.info.buffer;
} else {
isleaf = ((vrank % radix != 0) || (vrank == tsize - 1));
ucc_assert(args->coll_type == UCC_COLL_TYPE_GATHER);
is_leaf = ((vrank % radix != 0) || (vrank == tsize - 1));
if (vrank == 0) {
task->gather_kn.scratch = args->dst.info.buffer;
} else if (isleaf) {
} else if (is_leaf) {
task->gather_kn.scratch = args->src.info.buffer;
} else {
buffer_size = calc_buffer_size(vrank, task->gather_kn.radix, tsize);
Expand Down

0 comments on commit 1f2ba86

Please sign in to comment.