Skip to content

Commit

Permalink
TL/UCP: bruck bcopy
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev committed Nov 15, 2023
1 parent 17fc10e commit 03d0052
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 22 deletions.
141 changes: 119 additions & 22 deletions src/components/tl/ucp/alltoall/alltoall_bruck.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

/*
 * Values stored in task->alltoall_bruck.phase. The progress function reads
 * the field on entry to decide where to resume; the start function sets it
 * to PHASE_MERGE.
 *
 * NOTE(review): this listing is a rendered diff, so PHASE_SENDRECV shows up
 * twice (the line before the change and the line after it, with the added
 * trailing comma).
 */
enum {
PHASE_MERGE,
PHASE_SENDRECV
PHASE_SENDRECV,
/* Set just before the final executor copy into args->dst.info.buffer is
 * posted; when progress is re-entered in this phase after that copy has
 * finished, the task status becomes UCC_OK. */
PHASE_BCOPY
};

static inline int msb_pos_for_level(unsigned int nthbit, ucc_rank_t number)
Expand All @@ -33,7 +34,8 @@ static inline int msb_pos_for_level(unsigned int nthbit, ucc_rank_t number)
return msb_set;
}

static inline int find_seg_index(ucc_rank_t seg_index, int level, int nsegs_per_rblock)
static inline int find_seg_index(ucc_rank_t seg_index, int level,
int nsegs_per_rblock)
{
int block, blockseg;

Expand All @@ -53,7 +55,8 @@ static inline int find_seg_index(ucc_rank_t seg_index, int level, int nsegs_per_
return block * nsegs_per_rblock + blockseg;
}

ucc_status_t ucc_tl_ucp_alltoall_bruck_backward_rotation(void *dst, void *src,
ucc_status_t ucc_tl_ucp_alltoall_bruck_backward_rotation(void *dst,
void *src,
ucc_rank_t trank,
ucc_rank_t tsize,
size_t seg_size)
Expand Down Expand Up @@ -107,18 +110,38 @@ void ucc_tl_ucp_alltoall_bruck_progress(ucc_coll_task_t *coll_task)
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);
ucc_coll_args_t *args = &TASK_ARGS(task);
void *scratch = task->alltoall_bruck.scratch_mc_header->addr;
void *mergebuf = args->dst.info.buffer;
void *mergebuf = task->alltoall_bruck.dst;
const ucc_rank_t nrecv_segs = tsize / 2;
const size_t seg_size = ucc_dt_size(args->src.info.datatype) *
args->src.info.count / tsize;
void *data;
ucc_memory_type_t smtype = args->src.info.mem_type;
ucc_memory_type_t dmtype = args->dst.info.mem_type;
ucc_rank_t sendto, recvfrom, step, index;
ucc_rank_t level, snd_count;
int send_buffer_index;
ucc_status_t st;
ucc_ee_executor_t *exec;
ucc_ee_executor_task_args_t eargs;

if (task->alltoall_bruck.etask != NULL) {
st = ucc_ee_executor_task_test(task->alltoall_bruck.etask);
if (st == UCC_OK) {
ucc_ee_executor_task_finalize(task->alltoall_bruck.etask);
task->alltoall_bruck.etask = NULL;
} else {
if (ucc_unlikely(st < 0)) {
task->super.status = st;
}
return;
}
}

if (task->alltoall_bruck.phase == PHASE_SENDRECV) {
goto ALLTOALL_BRUCK_PHASE_SENDRECV;
} else if (task->alltoall_bruck.phase == PHASE_BCOPY) {
task->super.status = UCC_OK;
goto out;
}

step = 1 << (task->alltoall_bruck.iteration - 1);
Expand All @@ -133,7 +156,7 @@ void ucc_tl_ucp_alltoall_bruck_progress(ucc_coll_task_t *coll_task)
index = GET_NEXT_BRUCK_NUM(index, RADIX, step)) {
send_buffer_index = find_seg_index(index, level + 1, nrecv_segs);
if (send_buffer_index == -1) {
data = PTR_OFFSET(args->src.info.buffer,
data = PTR_OFFSET(task->alltoall_bruck.src,
((index + trank) % tsize) * seg_size);
} else {
data = PTR_OFFSET(scratch, send_buffer_index * seg_size);
Expand Down Expand Up @@ -165,29 +188,89 @@ void ucc_tl_ucp_alltoall_bruck_progress(ucc_coll_task_t *coll_task)
step = 1 << (task->alltoall_bruck.iteration - 1);
}

st = ucc_mc_memcpy(PTR_OFFSET(args->dst.info.buffer, trank * seg_size),
PTR_OFFSET(args->src.info.buffer, trank * seg_size),
st = ucc_mc_memcpy(PTR_OFFSET(task->alltoall_bruck.dst, trank * seg_size),
PTR_OFFSET(task->alltoall_bruck.src, trank * seg_size),
seg_size, UCC_MEMORY_TYPE_HOST, UCC_MEMORY_TYPE_HOST);
if (ucc_unlikely(st != UCC_OK)) {
task->super.status = st;
return;
}
task->super.status =
ucc_tl_ucp_alltoall_bruck_backward_rotation(args->dst.info.buffer,
scratch, trank, tsize,
seg_size);
st = ucc_tl_ucp_alltoall_bruck_backward_rotation(task->alltoall_bruck.dst,
scratch, trank, tsize,
seg_size);
if (ucc_unlikely(st != UCC_OK)) {
task->super.status = st;
return;
}

if (smtype != UCC_MEMORY_TYPE_HOST || dmtype != UCC_MEMORY_TYPE_HOST) {
task->alltoall_bruck.phase = PHASE_BCOPY;
st = ucc_coll_task_get_executor(&task->super, &exec);
if (ucc_unlikely(st != UCC_OK)) {
task->super.status = st;
return;
}

eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY;
eargs.copy.src = task->alltoall_bruck.dst;
eargs.copy.dst = args->dst.info.buffer;
eargs.copy.len = seg_size * tsize;
st = ucc_ee_executor_task_post(exec, &eargs,
&task->alltoall_bruck.etask);
if (ucc_unlikely(st != UCC_OK)) {
task->super.status = st;
return;
}
st = ucc_ee_executor_task_test(task->alltoall_bruck.etask);
if (st == UCC_OK) {
ucc_ee_executor_task_finalize(task->alltoall_bruck.etask);
task->alltoall_bruck.etask = NULL;
} else {
if (ucc_unlikely(st < 0)) {
task->super.status = st;
}
return;
}
}

task->super.status = UCC_OK;
out:
return;
}

/*
 * Post callback of the Bruck alltoall task (installed as task->super.post by
 * the init function).
 *
 * Resets the task state and, when either the source or the destination
 * buffer is not host memory, posts an executor copy of the whole source
 * buffer into task->alltoall_bruck.src. The task is then placed on the
 * progress queue.
 *
 * Returns the failing status if obtaining the executor or posting the copy
 * fails; otherwise returns whatever ucc_progress_queue_enqueue() returns.
 *
 * NOTE(review): this listing is a rendered diff, so removed and added lines
 * appear side by side (the duplicated task/team declarations and the two
 * ucc_tl_ucp_task_reset() calls below).
 */
ucc_status_t ucc_tl_ucp_alltoall_bruck_start(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_coll_args_t *args = &TASK_ARGS(task);
/* Total size in bytes of the source buffer (all segments). */
size_t size = ucc_dt_size(args->src.info.datatype) *
args->src.info.count;
ucc_ee_executor_t *exec;
ucc_ee_executor_task_args_t eargs;
ucc_status_t status;

ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
task->alltoall_bruck.iteration = 1;
task->alltoall_bruck.phase = PHASE_MERGE;
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
/* No executor copy outstanding yet; the progress function only tests
 * etask when it is non-NULL. */
task->alltoall_bruck.etask = NULL;

if ((args->src.info.mem_type != UCC_MEMORY_TYPE_HOST) ||
(args->dst.info.mem_type != UCC_MEMORY_TYPE_HOST)) {
status = ucc_coll_task_get_executor(&task->super, &exec);
if (ucc_unlikely(status != UCC_OK)) {
return status;
}

/* Copy the user's source buffer into task->alltoall_bruck.src. The copy
 * is only posted here; its completion is checked via etask at the top
 * of the progress function. */
eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY;
eargs.copy.src = args->src.info.buffer;
eargs.copy.dst = task->alltoall_bruck.src;
eargs.copy.len = size;
status = ucc_ee_executor_task_post(exec, &eargs,
&task->alltoall_bruck.etask);
if (ucc_unlikely(status != UCC_OK)) {
return status;
}
}

return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
}
Expand All @@ -199,25 +282,28 @@ ucc_status_t ucc_tl_ucp_alltoall_bruck_init(ucc_base_coll_args_t *coll_args,
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(tl_team);
ucc_coll_args_t *args = &coll_args->args;
size_t seg_size = ucc_dt_size(args->src.info.datatype) *
args->src.info.count / tsize;
size_t ssize = ucc_dt_size(args->src.info.datatype) *
args->src.info.count;
size_t seg_size = ssize / tsize;
int bcopy = 0;
size_t scratch_size;
ucc_tl_ucp_task_t *task;
ucc_status_t status;

if ((coll_args->args.src.info.mem_type != UCC_MEMORY_TYPE_HOST) ||
(coll_args->args.dst.info.mem_type != UCC_MEMORY_TYPE_HOST)) {
status = UCC_ERR_NOT_SUPPORTED;
goto out;
}
ALLTOALL_TASK_CHECK(coll_args->args, tl_team);

task = ucc_tl_ucp_init_task(coll_args, team);
task->super.post = ucc_tl_ucp_alltoall_bruck_start;
task->super.progress = ucc_tl_ucp_alltoall_bruck_progress;
task->super.finalize = ucc_tl_ucp_alltoall_bruck_finalize;
task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;

scratch_size = lognum(tsize) * ucc_div_round_up(tsize, 2) * seg_size;
if ((coll_args->args.src.info.mem_type != UCC_MEMORY_TYPE_HOST) ||
(coll_args->args.dst.info.mem_type != UCC_MEMORY_TYPE_HOST)) {
bcopy = 1;
scratch_size += 2 * ssize;
}

status = ucc_mc_alloc(&task->alltoall_bruck.scratch_mc_header,
scratch_size, UCC_MEMORY_TYPE_HOST);
if (ucc_unlikely(status != UCC_OK)) {
Expand All @@ -226,6 +312,17 @@ ucc_status_t ucc_tl_ucp_alltoall_bruck_init(ucc_base_coll_args_t *coll_args,
return status;
}

if (bcopy) {
task->alltoall_bruck.src =
PTR_OFFSET(task->alltoall_bruck.scratch_mc_header->addr,
lognum(tsize) * ucc_div_round_up(tsize, 2) * seg_size);
task->alltoall_bruck.dst =
PTR_OFFSET(task->alltoall_bruck.src, ssize);
} else {
task->alltoall_bruck.src = args->src.info.buffer;
task->alltoall_bruck.dst = args->dst.info.buffer;
}

*task_h = &task->super;
return UCC_OK;

Expand Down
3 changes: 3 additions & 0 deletions src/components/tl/ucp/tl_ucp_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ typedef struct ucc_tl_ucp_task {
} alltoallv_hybrid;
struct {
ucc_mc_buffer_header_t *scratch_mc_header;
ucc_ee_executor_task_t *etask;
void *src;
void *dst;
ucc_rank_t iteration;
int phase;
} alltoall_bruck;
Expand Down

0 comments on commit 03d0052

Please sign in to comment.