Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/UCP: Allgather Bruck algorithm #898

Merged
merged 14 commits into from
Feb 26, 2024
1 change: 1 addition & 0 deletions src/components/tl/ucp/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ allgather = \
allgather/allgather.c \
allgather/allgather_ring.c \
allgather/allgather_neighbor.c \
allgather/allgather_bruck.c \
allgather/allgather_knomial.c

allgatherv = \
Expand Down
4 changes: 4 additions & 0 deletions src/components/tl/ucp/allgather/allgather.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ ucc_base_coll_alg_info_t
{.id = UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR,
.name = "neighbor",
.desc = "O(N) Neighbor Exchange N/2 steps"},
[UCC_TL_UCP_ALLGATHER_ALG_BRUCK] =
{.id = UCC_TL_UCP_ALLGATHER_ALG_BRUCK,
.name = "bruck",
.desc = "O(log(N)) Variation of Bruck algorithm"},
[UCC_TL_UCP_ALLGATHER_ALG_LAST] = {
.id = 0, .name = NULL, .desc = NULL}};

Expand Down
12 changes: 12 additions & 0 deletions src/components/tl/ucp/allgather/allgather.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ enum {
UCC_TL_UCP_ALLGATHER_ALG_KNOMIAL,
UCC_TL_UCP_ALLGATHER_ALG_RING,
UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR,
UCC_TL_UCP_ALLGATHER_ALG_BRUCK,
UCC_TL_UCP_ALLGATHER_ALG_LAST
};

Expand Down Expand Up @@ -56,6 +57,17 @@ void ucc_tl_ucp_allgather_neighbor_progress(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_allgather_neighbor_start(ucc_coll_task_t *task);

/* Bruck */
ucc_status_t ucc_tl_ucp_allgather_bruck_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

void ucc_tl_ucp_allgather_bruck_progress(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_allgather_bruck_start(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_allgather_bruck_finalize(ucc_coll_task_t *coll_task);

/* Uses allgather_kn_radix from config */
ucc_status_t ucc_tl_ucp_allgather_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
Expand Down
258 changes: 258 additions & 0 deletions src/components/tl/ucp/allgather/allgather_bruck.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
/**
* Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
#include "config.h"
#include "tl_ucp.h"
#include "allgather.h"
#include "core/ucc_progress_queue.h"
#include "tl_ucp_sendrecv.h"
#include "utils/ucc_math.h"
#include "utils/ucc_coll_utils.h"
#include "components/mc/ucc_mc.h"
#include <stdio.h>

ucc_status_t ucc_tl_ucp_allgather_bruck_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
ucc_tl_ucp_task_t *task = ucc_tl_ucp_init_task(coll_args, team);
ucc_status_t status = UCC_OK;
ucc_rank_t trank = UCC_TL_TEAM_RANK(tl_team);
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(tl_team);
ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type;
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
size_t count = TASK_ARGS(task).dst.info.count;
size_t data_size = (count / tsize) * ucc_dt_size(dt);
size_t scratch_size = (tsize - trank) * data_size;

if (!ucc_coll_args_is_predefined_dt(&TASK_ARGS(task), UCC_RANK_INVALID)) {
tl_error(UCC_TASK_LIB(task), "user defined datatype is not supported");
status = UCC_ERR_NOT_SUPPORTED;
goto out;
}
tl_trace(UCC_TASK_LIB(task), "ucc_tl_ucp_allgather_bruck_init");

task->super.post = ucc_tl_ucp_allgather_bruck_start;
task->super.progress = ucc_tl_ucp_allgather_bruck_progress;
task->super.finalize = ucc_tl_ucp_allgather_bruck_finalize;

/* allocate scratch buffer only on non root rank */
if (trank != 0) {
if (UCC_MEMORY_TYPE_HOST != rmem) {
scratch_size = tsize * data_size;
}
status = ucc_mc_alloc(&task->allgather_bruck.scratch_header,
scratch_size, UCC_MEMORY_TYPE_HOST);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(UCC_TASK_LIB(task), "failed to allocate scratch buffer");
ucc_tl_ucp_coll_finalize(&task->super);
goto out;
}
task->allgather_bruck.scratch_size = scratch_size;
} else {
task->allgather_bruck.scratch_header = NULL;
task->allgather_bruck.scratch_size = 0;
}
out:
if (status != UCC_OK) {
ucc_tl_ucp_put_task(task);
return status;
}

*task_h = &task->super;
return status;
}

ucc_status_t ucc_tl_ucp_allgather_bruck_finalize(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_status_t global_status = UCC_OK;
ucc_status_t status;

tl_trace(UCC_TASK_LIB(task), "ucc_tl_ucp_allgather_bruck_finalize");

if (task->allgather_bruck.scratch_header != NULL) {
/* deallocate scratch buffer */
global_status = ucc_mc_free(task->allgather_bruck.scratch_header);
if (ucc_unlikely(global_status != UCC_OK)) {
tl_error(UCC_TASK_LIB(task),
"failed to free scratch buffer memory");
}
task->allgather_bruck.scratch_size = 0;
}

status = ucc_tl_ucp_coll_finalize(&task->super);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(UCC_TASK_LIB(task),
"failed to finalize allgather bruck collective");
global_status = status;
}
return global_status;
}

/* Inspired by implementation: https://github.com/open-mpi/ompi/blob/main/ompi/mca/coll/base/coll_base_allgather.c */
void ucc_tl_ucp_allgather_bruck_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t trank = UCC_TL_TEAM_RANK(team);
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);
void *rbuf = TASK_ARGS(task).dst.info.buffer;
ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type;
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
size_t count = TASK_ARGS(task).dst.info.count;
ucc_mc_buffer_header_t *scratch_header =
task->allgather_bruck.scratch_header;
size_t scratch_size = task->allgather_bruck.scratch_size;
size_t data_size = (count / tsize) * ucc_dt_size(dt);
ucc_rank_t recvfrom, sendto;
ucc_status_t status;
size_t blockcount, distance;
void *tmprecv, *tmpsend;

if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
return;
}

/* On each step doubles distance */
distance = 1 << task->tagged.recv_posted;
tmpsend = rbuf;
while (distance < tsize) {

recvfrom = (trank + distance) % tsize;
sendto = (trank + tsize - distance) % tsize;

tmprecv = PTR_OFFSET(tmpsend, distance * data_size);

if (distance <= tsize >> 1) {
blockcount = distance;
} else {
/* send-recv all reminder */
blockcount = tsize - distance;
}

/* Sendreceive */
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(tmpsend, blockcount * data_size, rmem,
sendto, team, task),
task, out);
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(tmprecv, blockcount * data_size, rmem,
ikryukov marked this conversation as resolved.
Show resolved Hide resolved
recvfrom, team, task),
task, out);

if (UCC_INPROGRESS == ucc_tl_ucp_test_recv(task)) {
return;
}

distance = 1 << task->tagged.recv_posted;
}

if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
return;
}

/* post processing step */
if (trank != 0) {
if (UCC_MEMORY_TYPE_HOST == rmem) {
// copy blocks [0 .. (size - rank - 1)] from rbuf to shift buffer
status = ucc_mc_memcpy(scratch_header->addr, rbuf, scratch_size,
UCC_MEMORY_TYPE_HOST, rmem);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(UCC_TASK_LIB(task),
"failed to copy data to scratch buffer");
task->super.status = status;
return;
}
// move blocks [(size - rank) .. size] from rbuf to beginning of rbuf
// TODO: rewrite to cycle to get rid of overlap
memmove(rbuf, PTR_OFFSET(rbuf, scratch_size), trank * data_size);
// copy blocks from shift buffer starting at block [rank] in rbuf.
status = ucc_mc_memcpy(PTR_OFFSET(rbuf, trank * data_size),
scratch_header->addr, scratch_size, rmem,
UCC_MEMORY_TYPE_HOST);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(UCC_TASK_LIB(task),
"failed to copy data from scratch to rbuff buffer");
task->super.status = status;
return;
}
} else {
/* In case of non host memory we perform two copy to host buffer and then back to device, 3 memcopy in total */
/* TODO: replace with generic kernel to do bruck post step in sinle launch on device */
status = ucc_mc_memcpy(
ikryukov marked this conversation as resolved.
Show resolved Hide resolved
PTR_OFFSET(scratch_header->addr, trank * data_size), rbuf,
(tsize - trank) * data_size, UCC_MEMORY_TYPE_HOST, rmem);
samnordmann marked this conversation as resolved.
Show resolved Hide resolved
if (ucc_unlikely(status != UCC_OK)) {
tl_error(UCC_TASK_LIB(task),
"failed to copy first data part to scratch buffer");
task->super.status = status;
return;
}
status =
ucc_mc_memcpy(scratch_header->addr,
PTR_OFFSET(rbuf, (tsize - trank) * data_size),
trank * data_size, UCC_MEMORY_TYPE_HOST, rmem);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(UCC_TASK_LIB(task),
"failed to copy second data part to scratch buffer");
task->super.status = status;
return;
}
status =
ucc_mc_memcpy(rbuf, scratch_header->addr, tsize * data_size,
rmem, UCC_MEMORY_TYPE_HOST);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(UCC_TASK_LIB(task),
"failed to copy from scratch buffer to dst");
task->super.status = status;
return;
ikryukov marked this conversation as resolved.
Show resolved Hide resolved
}
}
}

ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task));

task->super.status = UCC_OK;

out:
UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_bruck_done", 0);
}

ucc_status_t ucc_tl_ucp_allgather_bruck_start(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
size_t count = TASK_ARGS(task).dst.info.count;
void *sbuf = TASK_ARGS(task).src.info.buffer;
void *rbuf = TASK_ARGS(task).dst.info.buffer;
ucc_memory_type_t smem = TASK_ARGS(task).src.info.mem_type;
ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type;
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
ucc_rank_t trank = UCC_TL_TEAM_RANK(team);
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);
size_t data_size = (count / tsize) * ucc_dt_size(dt);
ucc_status_t status;

UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_bruck_start", 0);
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);

/* initial step: copy data on non root ranks to the beginning of buffer */
if (!UCC_IS_INPLACE(TASK_ARGS(task))) {
// not inplace: copy chunk from source buff to beginning of receive
status = ucc_mc_memcpy(rbuf, sbuf, data_size, rmem, smem);
if (ucc_unlikely(UCC_OK != status)) {
return status;
}
} else if (trank != 0) {
// inplace: copy chunk to the begin
status = ucc_mc_memcpy(rbuf, PTR_OFFSET(rbuf, data_size * trank),
data_size, rmem, rmem);
if (ucc_unlikely(UCC_OK != status)) {
return status;
}
}

return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
}
3 changes: 3 additions & 0 deletions src/components/tl/ucp/tl_ucp_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@ ucc_status_t ucc_tl_ucp_alg_id_to_init(int alg_id, const char *alg_id_str,
case UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR:
*init = ucc_tl_ucp_allgather_neighbor_init;
break;
case UCC_TL_UCP_ALLGATHER_ALG_BRUCK:
*init = ucc_tl_ucp_allgather_bruck_init;
break;
default:
status = UCC_ERR_INVALID_PARAM;
break;
Expand Down
4 changes: 4 additions & 0 deletions src/components/tl/ucp/tl_ucp_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ typedef struct ucc_tl_ucp_task {
ucc_rank_t tsize,
int step);
} allgather_ring;
struct {
ucc_mc_buffer_header_t *scratch_header;
size_t scratch_size;
} allgather_bruck;
samnordmann marked this conversation as resolved.
Show resolved Hide resolved
struct {
ucc_rank_t dist;
uint32_t radix;
Expand Down
2 changes: 1 addition & 1 deletion test/gtest/coll/test_allgather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ INSTANTIATE_TEST_CASE_P(
#endif
::testing::Values(1,3,8192), // count
::testing::Values(TEST_INPLACE, TEST_NO_INPLACE),
::testing::Values("knomial", "ring", "neighbor")),
::testing::Values("knomial", "ring", "neighbor", "bruck")),
[](const testing::TestParamInfo<test_allgather_alg::ParamType>& info) {
std::string name;
name += ucc_datatype_str(std::get<0>(info.param));
Expand Down
Loading