Skip to content

Commit

Permalink
TL/UCP: Bruck algorithm initial
Browse files Browse the repository at this point in the history
  • Loading branch information
ikryukov committed Jan 16, 2024
1 parent 8c31870 commit d16c45d
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/components/tl/ucp/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ allgather = \
allgather/allgather.c \
allgather/allgather_ring.c \
allgather/allgather_neighbor.c \
allgather/allgather_bruck.c \
allgather/allgather_knomial.c

allgatherv = \
Expand Down
8 changes: 7 additions & 1 deletion src/components/tl/ucp/allgather/allgather.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,16 @@ ucc_base_coll_alg_info_t
{.id = UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR,
.name = "neighbor",
.desc = "O(N) Neighbor Exchange N/2 steps"},
[UCC_TL_UCP_ALLGATHER_ALG_BRUCK] =
{.id = UCC_TL_UCP_ALLGATHER_ALG_BRUCK,
.name = "bruck",
.desc = "O(log(N)) Variation of Bruck algorithm"},
[UCC_TL_UCP_ALLGATHER_ALG_LAST] = {
.id = 0, .name = NULL, .desc = NULL}};

ucc_status_t ucc_tl_ucp_allgather_init(ucc_tl_ucp_task_t *task)
{
printf ("HELLO\n");
return ucc_tl_ucp_allgather_ring_init_common(task);
}

Expand All @@ -36,7 +41,7 @@ char *ucc_tl_ucp_allgather_score_str_get(ucc_tl_ucp_team_t *team)
int max_size = ALLGATHER_MAX_PATTERN_SIZE;
int algo_num = UCC_TL_TEAM_SIZE(team) % 2
? UCC_TL_UCP_ALLGATHER_ALG_RING
: UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR;
: UCC_TL_UCP_ALLGATHER_ALG_BRUCK;
char *str = ucc_malloc(max_size * sizeof(char));
ucc_sbgp_t *sbgp;

Expand All @@ -46,6 +51,7 @@ char *ucc_tl_ucp_allgather_score_str_get(ucc_tl_ucp_team_t *team)
algo_num = UCC_TL_UCP_ALLGATHER_ALG_RING;
}
}
fprintf(stderr, "Algo num: %d\n", algo_num);
ucc_snprintf_safe(str, max_size,
UCC_TL_UCP_ALLGATHER_DEFAULT_ALG_SELECT_STR, algo_num);
return str;
Expand Down
10 changes: 10 additions & 0 deletions src/components/tl/ucp/allgather/allgather.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ enum {
UCC_TL_UCP_ALLGATHER_ALG_KNOMIAL,
UCC_TL_UCP_ALLGATHER_ALG_RING,
UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR,
UCC_TL_UCP_ALLGATHER_ALG_BRUCK,
UCC_TL_UCP_ALLGATHER_ALG_LAST
};

Expand Down Expand Up @@ -56,6 +57,15 @@ void ucc_tl_ucp_allgather_neighbor_progress(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_allgather_neighbor_start(ucc_coll_task_t *task);

/* Bruck */
ucc_status_t ucc_tl_ucp_allgather_bruck_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

void ucc_tl_ucp_allgather_bruck_progress(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_allgather_bruck_start(ucc_coll_task_t *task);

/* Uses allgather_kn_radix from config */
ucc_status_t ucc_tl_ucp_allgather_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
Expand Down
197 changes: 197 additions & 0 deletions src/components/tl/ucp/allgather/allgather_bruck.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
/**
* Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
#include "config.h"
#include "tl_ucp.h"
#include "allgather.h"
#include "core/ucc_progress_queue.h"
#include "tl_ucp_sendrecv.h"
#include "utils/ucc_math.h"
#include "utils/ucc_coll_utils.h"
#include "components/mc/ucc_mc.h"
#include <stdio.h>

ucc_status_t ucc_tl_ucp_allgather_bruck_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
{
ucc_status_t status = UCC_OK;
ucc_tl_ucp_task_t *task;
ucc_tl_ucp_team_t *ucp_team;

task = ucc_tl_ucp_init_task(coll_args, team);
ucp_team = TASK_TEAM(task);

if (!ucc_coll_args_is_predefined_dt(&TASK_ARGS(task), UCC_RANK_INVALID)) {
tl_error(UCC_TASK_LIB(task), "user defined datatype is not supported");
status = UCC_ERR_NOT_SUPPORTED;
goto out;
}
printf("ucc_tl_ucp_allgather_bruck_init\n");
if (UCC_TL_TEAM_SIZE(ucp_team) % 2) {
tl_debug(UCC_TASK_LIB(task),
"odd team size is not supported, switching to ring");
status = ucc_tl_ucp_allgather_ring_init_common(task);
} else {
task->super.post = ucc_tl_ucp_allgather_bruck_start;
task->super.progress = ucc_tl_ucp_allgather_bruck_progress;
}

out:
if (status != UCC_OK) {
ucc_tl_ucp_put_task(task);
return status;
}

*task_h = &task->super;
return status;
}

/* Inspired by implementation: https://github.com/open-mpi/ompi/blob/main/ompi/mca/coll/base/coll_base_allgather.c */
void ucc_tl_ucp_allgather_bruck_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t trank = UCC_TL_TEAM_RANK(team);
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);
void *rbuf = TASK_ARGS(task).dst.info.buffer;
ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type;
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
size_t count = TASK_ARGS(task).dst.info.count;
size_t data_size = (count / tsize) * ucc_dt_size(dt);
ucc_rank_t recvfrom, sendto;
ucc_status_t status;
size_t blockcount, distance;
void *tmprecv, *tmpsend;

if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
return;
}

/* On each step doubles distance */
distance = 1 << task->tagged.send_posted;
printf("bruck\n");
tmpsend = rbuf;
while (distance < (tsize)) {

recvfrom = (trank + distance) % tsize;
sendto = (trank + tsize - distance) % tsize;

tmprecv = PTR_OFFSET(tmpsend, distance * data_size);

if (distance <= tsize >> 1) {
blockcount = distance;
} else {
/* send-recv all reminder*/
blockcount = tsize - distance;
}

/* Sendreceive */
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(tmpsend, blockcount * data_size, rmem,
sendto, team, task),
task, out);
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(tmprecv, blockcount * data_size, rmem,
recvfrom, team, task),
task, out);

if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
return;
}
}

/* post processing step */
if (trank != 0) {
ucc_mc_buffer_header_t *scratch_header;
size_t scratch_size = (tsize - trank) * data_size;
/* allocate scratch buffer */
status =
ucc_mc_alloc(&scratch_header, scratch_size, UCC_MEMORY_TYPE_HOST);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(UCC_TASK_LIB(task), "failed to allocate scratch buffer");
ucc_tl_ucp_coll_finalize(&task->super);
return;
}

status = ucc_mc_memcpy(scratch_header->addr, rbuf, scratch_size,
UCC_MEMORY_TYPE_HOST, UCC_MEMORY_TYPE_HOST);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(UCC_TASK_LIB(task),
"failed to copy data to scratch buffer");
ucc_tl_ucp_coll_finalize(&task->super);
return;
}

status = ucc_mc_memcpy(rbuf, PTR_OFFSET(rbuf, scratch_size),
trank * data_size, UCC_MEMORY_TYPE_HOST,
UCC_MEMORY_TYPE_HOST);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(UCC_TASK_LIB(task),
"failed to move data inside rbuff buffer");
ucc_tl_ucp_coll_finalize(&task->super);
return;
}

status = ucc_mc_memcpy(PTR_OFFSET(rbuf, trank * data_size),
scratch_header->addr, scratch_size,
UCC_MEMORY_TYPE_HOST, UCC_MEMORY_TYPE_HOST);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(UCC_TASK_LIB(task),
"failed to copy data from scratch to rbuff buffer");
ucc_tl_ucp_coll_finalize(&task->super);
return;
}

/* deallocate scratch buffer */
status = ucc_mc_free(scratch_header);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(UCC_TASK_LIB(task),
"failed to free scratch buffer memory");
ucc_tl_ucp_coll_finalize(&task->super);
return;
}
}

ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task));
task->super.status = UCC_OK;

out:
UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_bruck_done", 0);
}

ucc_status_t ucc_tl_ucp_allgather_bruck_start(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
size_t count = TASK_ARGS(task).dst.info.count;
void *sbuf = TASK_ARGS(task).src.info.buffer;
void *rbuf = TASK_ARGS(task).dst.info.buffer;
ucc_memory_type_t smem = TASK_ARGS(task).src.info.mem_type;
ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type;
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
ucc_rank_t trank = UCC_TL_TEAM_RANK(team);
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);
size_t data_size = (count / tsize) * ucc_dt_size(dt);
ucc_status_t status;

UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_bruck_start", 0);
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);

/* initial step: copy data on non root ranks to the beginning of buffer */
if (!UCC_IS_INPLACE(TASK_ARGS(task))) {
status = ucc_mc_memcpy(rbuf, PTR_OFFSET(sbuf, data_size * trank),
data_size, rmem, smem);
if (ucc_unlikely(UCC_OK != status)) {
return status;
}
} else if (trank != 0) {
status = ucc_mc_memcpy(rbuf, PTR_OFFSET(rbuf, data_size * trank),
data_size, rmem, rmem);
if (ucc_unlikely(UCC_OK != status)) {
return status;
}
}

return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
}
3 changes: 3 additions & 0 deletions src/components/tl/ucp/tl_ucp_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,9 @@ ucc_status_t ucc_tl_ucp_alg_id_to_init(int alg_id, const char *alg_id_str,
case UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR:
*init = ucc_tl_ucp_allgather_neighbor_init;
break;
case UCC_TL_UCP_ALLGATHER_ALG_BRUCK:
*init = ucc_tl_ucp_allgather_bruck_init;
break;
default:
status = UCC_ERR_INVALID_PARAM;
break;
Expand Down
2 changes: 1 addition & 1 deletion test/gtest/coll/test_allgather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ INSTANTIATE_TEST_CASE_P(
#endif
::testing::Values(1,3,8192), // count
::testing::Values(TEST_INPLACE, TEST_NO_INPLACE),
::testing::Values("knomial", "ring", "neighbor")),
::testing::Values("knomial", "ring", "neighbor", "bruck")),
[](const testing::TestParamInfo<test_allgather_alg::ParamType>& info) {
std::string name;
name += ucc_datatype_str(std::get<0>(info.param));
Expand Down

0 comments on commit d16c45d

Please sign in to comment.