From d16c45db2565a25883100162630a992e9319c17c Mon Sep 17 00:00:00 2001
From: Ilya Kryukov
Date: Tue, 16 Jan 2024 11:38:56 +0100
Subject: [PATCH] TL/UCP: Bruck algorithm initial

---
Review notes (this section is ignored by git am):
 * dropped the debug printf/fprintf calls and the stdio include
 * progress: the distance is recomputed after a round completes inside the
   loop, previously the same round was posted again
 * progress: failures are reported through task->super.status, the task is
   no longer finalized from the progress function
 * rotation: staged through a full-size host scratch copy (no overlapping
   copy), copies name the memory type of the receive buffer, the scratch
   buffer is released on every path
 * start: the send buffer is read at offset 0 when not in place
 * the patch was re-flowed by hand: context whitespace is reconstructed,
   text lost inside angle brackets is left as found and the index hashes of
   allgather.c / allgather_bruck.c are stale; apply with --ignore-whitespace

 src/components/tl/ucp/Makefile.am             |   1 +
 src/components/tl/ucp/allgather/allgather.c   |   6 +-
 src/components/tl/ucp/allgather/allgather.h   |  10 +
 .../tl/ucp/allgather/allgather_bruck.c        | 236 ++++++++++++++++++
 src/components/tl/ucp/tl_ucp_coll.c           |   3 +
 test/gtest/coll/test_allgather.cc             |   2 +-
 6 files changed, 256 insertions(+), 2 deletions(-)
 create mode 100644 src/components/tl/ucp/allgather/allgather_bruck.c

diff --git a/src/components/tl/ucp/Makefile.am b/src/components/tl/ucp/Makefile.am
index badc741d99..fb04c61f14 100644
--- a/src/components/tl/ucp/Makefile.am
+++ b/src/components/tl/ucp/Makefile.am
@@ -14,6 +14,7 @@ allgather =                        \
 	allgather/allgather.c          \
 	allgather/allgather_ring.c     \
 	allgather/allgather_neighbor.c \
+	allgather/allgather_bruck.c    \
 	allgather/allgather_knomial.c
 
 allgatherv =                        \
diff --git a/src/components/tl/ucp/allgather/allgather.c b/src/components/tl/ucp/allgather/allgather.c
index 926b732e55..96309daa86 100644
--- a/src/components/tl/ucp/allgather/allgather.c
+++ b/src/components/tl/ucp/allgather/allgather.c
@@ -23,6 +23,10 @@ ucc_base_coll_alg_info_t
             {.id   = UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR,
              .name = "neighbor",
              .desc = "O(N) Neighbor Exchange N/2 steps"},
+        [UCC_TL_UCP_ALLGATHER_ALG_BRUCK] =
+            {.id   = UCC_TL_UCP_ALLGATHER_ALG_BRUCK,
+             .name = "bruck",
+             .desc = "O(log(N)) Variation of Bruck algorithm"},
         [UCC_TL_UCP_ALLGATHER_ALG_LAST] = {
             .id = 0, .name = NULL, .desc = NULL}};
 
@@ -36,7 +40,7 @@ char *ucc_tl_ucp_allgather_score_str_get(ucc_tl_ucp_team_t *team)
     int           max_size = ALLGATHER_MAX_PATTERN_SIZE;
     int           algo_num = UCC_TL_TEAM_SIZE(team) % 2
                                  ? UCC_TL_UCP_ALLGATHER_ALG_RING
-                                 : UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR;
+                                 : UCC_TL_UCP_ALLGATHER_ALG_BRUCK;
     char         *str      = ucc_malloc(max_size * sizeof(char));
     ucc_sbgp_t   *sbgp;
 
diff --git a/src/components/tl/ucp/allgather/allgather.h b/src/components/tl/ucp/allgather/allgather.h
index b68ab00e95..0e816e8838 100644
--- a/src/components/tl/ucp/allgather/allgather.h
+++ b/src/components/tl/ucp/allgather/allgather.h
@@ -12,6 +12,7 @@ enum {
     UCC_TL_UCP_ALLGATHER_ALG_KNOMIAL,
     UCC_TL_UCP_ALLGATHER_ALG_RING,
     UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR,
+    UCC_TL_UCP_ALLGATHER_ALG_BRUCK,
     UCC_TL_UCP_ALLGATHER_ALG_LAST
 };
 
@@ -56,6 +57,15 @@ void ucc_tl_ucp_allgather_neighbor_progress(ucc_coll_task_t *task);
 
 ucc_status_t ucc_tl_ucp_allgather_neighbor_start(ucc_coll_task_t *task);
 
+/* Bruck */
+ucc_status_t ucc_tl_ucp_allgather_bruck_init(ucc_base_coll_args_t *coll_args,
+                                             ucc_base_team_t      *team,
+                                             ucc_coll_task_t     **task_h);
+
+void ucc_tl_ucp_allgather_bruck_progress(ucc_coll_task_t *task);
+
+ucc_status_t ucc_tl_ucp_allgather_bruck_start(ucc_coll_task_t *task);
+
 /* Uses allgather_kn_radix from config */
 ucc_status_t ucc_tl_ucp_allgather_knomial_init(ucc_base_coll_args_t *coll_args,
                                                ucc_base_team_t      *team,
diff --git a/src/components/tl/ucp/allgather/allgather_bruck.c b/src/components/tl/ucp/allgather/allgather_bruck.c
new file mode 100644
index 0000000000..052f12802c
--- /dev/null
+++ b/src/components/tl/ucp/allgather/allgather_bruck.c
@@ -0,0 +1,236 @@
+/**
+ * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+#include "config.h"
+#include "tl_ucp.h"
+#include "allgather.h"
+#include "core/ucc_progress_queue.h"
+#include "tl_ucp_sendrecv.h"
+#include "utils/ucc_math.h"
+#include "utils/ucc_coll_utils.h"
+#include "components/mc/ucc_mc.h"
+
+/*
+ * Installs the Bruck start/progress callbacks on a new allgather task.
+ * Only predefined datatypes are accepted; odd team sizes are delegated to
+ * the ring implementation.
+ */
+ucc_status_t ucc_tl_ucp_allgather_bruck_init(ucc_base_coll_args_t *coll_args,
+                                             ucc_base_team_t      *team,
+                                             ucc_coll_task_t     **task_h)
+{
+    ucc_status_t       status = UCC_OK;
+    ucc_tl_ucp_task_t *task;
+    ucc_tl_ucp_team_t *ucp_team;
+
+    task     = ucc_tl_ucp_init_task(coll_args, team);
+    ucp_team = TASK_TEAM(task);
+
+    if (!ucc_coll_args_is_predefined_dt(&TASK_ARGS(task), UCC_RANK_INVALID)) {
+        tl_error(UCC_TASK_LIB(task), "user defined datatype is not supported");
+        status = UCC_ERR_NOT_SUPPORTED;
+        goto out;
+    }
+
+    if (UCC_TL_TEAM_SIZE(ucp_team) % 2) {
+        tl_debug(UCC_TASK_LIB(task),
+                 "odd team size is not supported, switching to ring");
+        status = ucc_tl_ucp_allgather_ring_init_common(task);
+    } else {
+        task->super.post     = ucc_tl_ucp_allgather_bruck_start;
+        task->super.progress = ucc_tl_ucp_allgather_bruck_progress;
+    }
+
+out:
+    if (status != UCC_OK) {
+        ucc_tl_ucp_put_task(task);
+        return status;
+    }
+
+    *task_h = &task->super;
+    return status;
+}
+
+/*
+ * Post-processing step. After the exchange rounds the receive buffer holds
+ * the blocks of ranks trank, trank + 1, ... (wrapping around) starting at
+ * offset 0; rotate it so that the block of rank i sits at i * data_size.
+ * The buffer is staged through a host scratch copy so that no copy has
+ * overlapping source and destination, and both copies name the real memory
+ * type of the receive buffer.
+ */
+static ucc_status_t ucc_tl_ucp_allgather_bruck_rotate(ucc_tl_ucp_task_t *task)
+{
+    ucc_tl_ucp_team_t      *team      = TASK_TEAM(task);
+    ucc_rank_t              trank     = UCC_TL_TEAM_RANK(team);
+    ucc_rank_t              tsize     = UCC_TL_TEAM_SIZE(team);
+    void                   *rbuf      = TASK_ARGS(task).dst.info.buffer;
+    ucc_memory_type_t       rmem      = TASK_ARGS(task).dst.info.mem_type;
+    ucc_datatype_t          dt        = TASK_ARGS(task).dst.info.datatype;
+    size_t                  count     = TASK_ARGS(task).dst.info.count;
+    size_t                  data_size = (count / tsize) * ucc_dt_size(dt);
+    size_t                  head_size = (tsize - trank) * data_size;
+    size_t                  tail_size = trank * data_size;
+    ucc_mc_buffer_header_t *scratch_header;
+    ucc_status_t            status, free_status;
+
+    if (tail_size == 0) {
+        /* rank 0 or an empty message: the layout is already final */
+        return UCC_OK;
+    }
+
+    status = ucc_mc_alloc(&scratch_header, head_size + tail_size,
+                          UCC_MEMORY_TYPE_HOST);
+    if (ucc_unlikely(status != UCC_OK)) {
+        tl_error(UCC_TASK_LIB(task), "failed to allocate scratch buffer");
+        return status;
+    }
+
+    status = ucc_mc_memcpy(scratch_header->addr, rbuf, head_size + tail_size,
+                           UCC_MEMORY_TYPE_HOST, rmem);
+    if (ucc_unlikely(status != UCC_OK)) {
+        tl_error(UCC_TASK_LIB(task),
+                 "failed to copy data to scratch buffer");
+        goto free_scratch;
+    }
+
+    /* blocks of ranks 0 .. trank - 1 sit at the end, move them first */
+    status = ucc_mc_memcpy(rbuf, PTR_OFFSET(scratch_header->addr, head_size),
+                           tail_size, rmem, UCC_MEMORY_TYPE_HOST);
+    if (ucc_unlikely(status != UCC_OK)) {
+        tl_error(UCC_TASK_LIB(task),
+                 "failed to move data inside rbuff buffer");
+        goto free_scratch;
+    }
+
+    /* blocks of ranks trank .. tsize - 1 follow at their final offset */
+    status = ucc_mc_memcpy(PTR_OFFSET(rbuf, tail_size), scratch_header->addr,
+                           head_size, rmem, UCC_MEMORY_TYPE_HOST);
+    if (ucc_unlikely(status != UCC_OK)) {
+        tl_error(UCC_TASK_LIB(task),
+                 "failed to copy data from scratch to rbuff buffer");
+    }
+
+free_scratch:
+    /* released on every path so that a failed copy does not leak it */
+    free_status = ucc_mc_free(scratch_header);
+    if (ucc_unlikely(free_status != UCC_OK)) {
+        tl_error(UCC_TASK_LIB(task),
+                 "failed to free scratch buffer memory");
+        if (status == UCC_OK) {
+            status = free_status;
+        }
+    }
+    return status;
+}
+
+/*
+ * Progress function: runs the exchange rounds and, once the last one has
+ * completed, rotates the receive buffer into rank order.
+ * Inspired by implementation:
+ * https://github.com/open-mpi/ompi/blob/main/ompi/mca/coll/base/coll_base_allgather.c
+ */
+void ucc_tl_ucp_allgather_bruck_progress(ucc_coll_task_t *coll_task)
+{
+    ucc_tl_ucp_task_t *task      = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
+    ucc_tl_ucp_team_t *team      = TASK_TEAM(task);
+    ucc_rank_t         trank     = UCC_TL_TEAM_RANK(team);
+    ucc_rank_t         tsize     = UCC_TL_TEAM_SIZE(team);
+    void              *rbuf      = TASK_ARGS(task).dst.info.buffer;
+    ucc_memory_type_t  rmem      = TASK_ARGS(task).dst.info.mem_type;
+    ucc_datatype_t     dt        = TASK_ARGS(task).dst.info.datatype;
+    size_t             count     = TASK_ARGS(task).dst.info.count;
+    size_t             data_size = (count / tsize) * ucc_dt_size(dt);
+    ucc_rank_t         recvfrom, sendto;
+    size_t             blockcount, distance;
+    void              *tmprecv;
+
+    if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
+        return;
+    }
+
+    /*
+     * tagged.send_posted serves as the round counter (one send is posted per
+     * round), so the distance doubles with every completed round.
+     */
+    distance = (size_t)1 << task->tagged.send_posted;
+    while (distance < tsize) {
+        recvfrom = (trank + distance) % tsize;
+        sendto   = (trank + tsize - distance) % tsize;
+        tmprecv  = PTR_OFFSET(rbuf, distance * data_size);
+
+        if (distance <= tsize >> 1) {
+            blockcount = distance;
+        } else {
+            /* last round: exchange only the remaining blocks */
+            blockcount = tsize - distance;
+        }
+
+        /* Sendreceive */
+        UCPCHECK_GOTO(ucc_tl_ucp_send_nb(rbuf, blockcount * data_size, rmem,
+                                         sendto, team, task),
+                      task, out);
+        UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(tmprecv, blockcount * data_size, rmem,
+                                         recvfrom, team, task),
+                      task, out);
+
+        if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
+            return;
+        }
+
+        /*
+         * The round completed without leaving the function: recompute the
+         * distance, otherwise the same round would be posted again.
+         */
+        distance = (size_t)1 << task->tagged.send_posted;
+    }
+
+    ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task));
+    /* report a failed rotation through the task status */
+    task->super.status = ucc_tl_ucp_allgather_bruck_rotate(task);
+
+out:
+    UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_bruck_done", 0);
+}
+
+/*
+ * Start function: places the local block at the beginning of the receive
+ * buffer, where the exchange rounds expect it, and enqueues the task.
+ */
+ucc_status_t ucc_tl_ucp_allgather_bruck_start(ucc_coll_task_t *coll_task)
+{
+    ucc_tl_ucp_task_t *task      = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
+    ucc_tl_ucp_team_t *team      = TASK_TEAM(task);
+    size_t             count     = TASK_ARGS(task).dst.info.count;
+    void              *sbuf      = TASK_ARGS(task).src.info.buffer;
+    void              *rbuf      = TASK_ARGS(task).dst.info.buffer;
+    ucc_memory_type_t  smem      = TASK_ARGS(task).src.info.mem_type;
+    ucc_memory_type_t  rmem      = TASK_ARGS(task).dst.info.mem_type;
+    ucc_datatype_t     dt        = TASK_ARGS(task).dst.info.datatype;
+    ucc_rank_t         trank     = UCC_TL_TEAM_RANK(team);
+    ucc_rank_t         tsize     = UCC_TL_TEAM_SIZE(team);
+    size_t             data_size = (count / tsize) * ucc_dt_size(dt);
+    ucc_status_t       status;
+
+    UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_bruck_start", 0);
+    ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
+
+    if (!UCC_IS_INPLACE(TASK_ARGS(task))) {
+        /* the send buffer carries only the local block: no rank offset */
+        status = ucc_mc_memcpy(rbuf, sbuf, data_size, rmem, smem);
+        if (ucc_unlikely(UCC_OK != status)) {
+            return status;
+        }
+    } else if (trank != 0) {
+        /* in place: the local block is stored at its rank offset */
+        status = ucc_mc_memcpy(rbuf, PTR_OFFSET(rbuf, data_size * trank),
+                               data_size, rmem, rmem);
+        if (ucc_unlikely(UCC_OK != status)) {
+            return status;
+        }
+    }
+
+    return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
+}
diff --git a/src/components/tl/ucp/tl_ucp_coll.c b/src/components/tl/ucp/tl_ucp_coll.c
index a1ba843d9f..200b63b40d 100644
--- a/src/components/tl/ucp/tl_ucp_coll.c
+++ b/src/components/tl/ucp/tl_ucp_coll.c
@@ -255,6 +255,9 @@ ucc_status_t ucc_tl_ucp_alg_id_to_init(int alg_id, const char *alg_id_str,
         case UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR:
             *init = ucc_tl_ucp_allgather_neighbor_init;
             break;
+        case UCC_TL_UCP_ALLGATHER_ALG_BRUCK:
+            *init = ucc_tl_ucp_allgather_bruck_init;
+            break;
         default:
             status = UCC_ERR_INVALID_PARAM;
             break;
diff --git a/test/gtest/coll/test_allgather.cc b/test/gtest/coll/test_allgather.cc
index e1cacac5ac..2577bdf26c 100644
--- a/test/gtest/coll/test_allgather.cc
+++ b/test/gtest/coll/test_allgather.cc
@@ -296,7 +296,7 @@ INSTANTIATE_TEST_CASE_P(
 #endif
         ::testing::Values(1,3,8192), // count
         ::testing::Values(TEST_INPLACE, TEST_NO_INPLACE),
-        ::testing::Values("knomial", "ring", "neighbor")),
+        ::testing::Values("knomial", "ring", "neighbor", "bruck")),
         [](const testing::TestParamInfo& info) {
             std::string name;
             name += ucc_datatype_str(std::get<0>(info.param));