diff --git a/config/m4/rocm.m4 b/config/m4/rocm.m4 index 356961939f..c7c4ca141d 100644 --- a/config/m4/rocm.m4 +++ b/config/m4/rocm.m4 @@ -39,11 +39,28 @@ AC_DEFUN([ROCM_BUILD_FLAGS], # Parse value of ARG into appropriate LIBS, LDFLAGS, and # CPPFLAGS variables. AC_DEFUN([HIP_BUILD_FLAGS], - $4="-D__HIP_PLATFORM_AMD__ -I$1/include/hip -I$1/include" - $3="-L$1/lib" + $4="-D__HIP_PLATFORM_AMD__ -I$1/include/hip -I$1/include -I$1/llvm/include" + $3="-L$1/lib -L$1/llvm/lib" $2="-lamdhip64" ) +# CHECK_ROCM_VERSION(HIP_VERSION_MAJOR, ROCM_VERSION_CONDITION) +# ---------------------------------------------------------- +# Checks ROCm version and marks condition as 1 (TRUE) or 0 (FALSE) +AC_DEFUN([CHECK_ROCM_VERSION], [ +AC_COMPILE_IFELSE( +[AC_LANG_PROGRAM([[#include <${with_rocm}/include/hip/hip_version.h> + ]], [[ +#if HIP_VERSION_MAJOR >= $1 +return 0; +#else +intr make+compilation_fail() +#endif + ]])], + [$2=1], + [$2=0]) +]) + # # Check for ROCm support # @@ -102,28 +119,25 @@ AS_IF([test "x$with_rocm" != "xno"], LDFLAGS="$SAVE_LDFLAGS" LIBS="$SAVE_LIBS" - #Check whether we run on ROCm 5.0 or higher - AC_COMPILE_IFELSE( - [AC_LANG_PROGRAM([[#include <${with_rocm}/include/rocm_version.h> - ]], [[ -#if ROCM_VERSION_MAJOR >= 5 -return 0; -#else -intr make+compilation_fail() -#endif - ]])], - [ROCM_VERSION_50_OR_GREATER=1], - [ROCM_VERSION_50_OR_GREATER=0]) - - HIP_BUILD_FLAGS([$with_rocm], [HIP_LIBS], [HIP_LDFLAGS], [HIP_CPPFLAGS]) - AC_MSG_CHECKING([if ROCm version is 5.0 or above]) - if test "$ROCM_VERSION_50_OR_GREATER" = "1" ; then + + # Check whether we run on ROCm 6.0 or higher + CHECK_ROCM_VERSION(6, ROCM_VERSION_60_OR_GREATER) + AC_MSG_CHECKING([if ROCm version is 6.0 or above]) + if test "$ROCM_VERSION_60_OR_GREATER" = "1" ; then AC_MSG_RESULT([yes]) else AC_MSG_RESULT([no]) - HIP_CPPFLAGS="${HIP_CPPFLAGS} -I${with_rocm}/hip/include" - HIP_LDFLAGS="${HIP_LDFLAGS} -L${with_rocm}/hip/lib" + # Check whether we run on ROCm 5.0-5.7 + CHECK_ROCM_VERSION(5, ROCM_VERSION_50_OR_GREATER) + AC_MSG_CHECKING([if ROCm version is 5.0 - 5.7]) + if test "$ROCM_VERSION_50_OR_GREATER" = "1" ; then + AC_MSG_RESULT([yes]) + else + AC_MSG_RESULT([no]) + HIP_CPPFLAGS="${HIP_CPPFLAGS} -I${with_rocm}/hip/include" + HIP_LDFLAGS="${HIP_LDFLAGS} -L${with_rocm}/hip/lib" + fi fi CPPFLAGS="$HIP_CPPFLAGS $CPPFLAGS" @@ -142,10 +156,17 @@ intr make+compilation_fail() LDFLAGS="$SAVE_LDFLAGS" LIBS="$SAVE_LIBS" - - AS_IF([test "x$hip_happy" = "xyes"], - [AC_PATH_PROG([HIPCC], [hipcc], [notfound], [$PATH:$with_rocm/bin])]) - AS_IF([test "$HIPCC" = "notfound"], [hip_happy="no"]) + if test "$ROCM_VERSION_60_OR_GREATER" = "1" ; then + AC_MSG_NOTICE([using amdclang as ROCm version is 6.0 or above]) + AS_IF([test "x$hip_happy" = "xyes"], + [AC_PATH_PROG([HIPCC], [amdclang], [notfound], [$PATH:$with_rocm/bin])]) + AS_IF([test "$HIPCC" = "notfound"], [hip_happy="no"]) + else + AC_MSG_NOTICE([using hipcc as ROCm version is 3.7.0 to ROCm 5.7.1]) + AS_IF([test "x$hip_happy" = "xyes"], + [AC_PATH_PROG([HIPCC], [hipcc], [notfound], [$PATH:$with_rocm/bin])]) + AS_IF([test "$HIPCC" = "notfound"], [hip_happy="no"]) + fi AS_IF([test "x$hip_happy" = "xyes"], [AC_DEFINE([HAVE_HIP], 1, [Enable HIP support]) diff --git a/cuda_lt.sh b/cuda_lt.sh index 68f9b5ff60..6601e3edcb 100755 --- a/cuda_lt.sh +++ b/cuda_lt.sh @@ -27,7 +27,9 @@ local_npic_filepath="${local_npic_dir}${o_filename}" mkdir -p $pic_dir tmpcmd="${@:3}" -if [[ "$tmpcmd" == *"hipcc"* ]]; then +if [[ "$tmpcmd" == *"amdclang"* ]]; then + cmd="${@:3:2} -x hip -target x86_64-unknown-linux-gnu --offload-arch=gfx908 --offload-arch=gfx90a --offload-arch=gfx940 --offload-arch=gfx941 --offload-arch=gfx942 --offload-arch=gfx1030 --offload-arch=gfx1100 --offload-arch=gfx1101 --offload-arch=gfx1102 --offload-arch=native ${@:5} -fPIC -O3 -o ${pic_filepath}" +elif [[ "$tmpcmd" == *"hipcc"* ]]; then cmd="${@:3} -fPIC -o ${pic_filepath}" else cmd="${@:3} -Xcompiler -fPIC -o ${pic_filepath}" @@ -35,7 +37,11 @@ fi echo $cmd $cmd -cmd="${@:3} -o ${npic_filepath}" +if [[ "$tmpcmd" == *"amdclang"* ]]; then + cmd="${@:3:2} -x hip -target x86_64-unknown-linux-gnu --offload-arch=gfx908 --offload-arch=gfx90a --offload-arch=gfx940 --offload-arch=gfx941 --offload-arch=gfx942 --offload-arch=gfx1030 --offload-arch=gfx1100 --offload-arch=gfx1101 --offload-arch=gfx1102 --offload-arch=native ${@:5} -O3 -o ${npic_filepath}" +else + cmd="${@:3} -o ${npic_filepath}" +fi echo $cmd $cmd diff --git a/src/coll_score/ucc_coll_score.c b/src/coll_score/ucc_coll_score.c index 7cc4f90af3..c99d33f9dc 100644 --- a/src/coll_score/ucc_coll_score.c +++ b/src/coll_score/ucc_coll_score.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -9,6 +9,16 @@ #include "utils/ucc_log.h" #include "utils/ucc_coll_utils.h" +char *ucc_score_to_str(ucc_score_t score, char *buf, size_t max) { + if (score == UCC_SCORE_MAX) { + ucc_strncpy_safe(buf, "inf", max); + } else { + ucc_snprintf_safe(buf, max, "%d", score); + } + + return buf; +} + ucc_status_t ucc_coll_score_alloc(ucc_coll_score_t **score) { ucc_coll_score_t *s = ucc_malloc(sizeof(*s), "ucc_coll_score"); diff --git a/src/coll_score/ucc_coll_score.h b/src/coll_score/ucc_coll_score.h index 16f0ba0b74..fa95e6a76a 100644 --- a/src/coll_score/ucc_coll_score.h +++ b/src/coll_score/ucc_coll_score.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -63,6 +63,8 @@ typedef struct ucc_coll_score { typedef struct ucc_score_map ucc_score_map_t; +char *ucc_score_to_str(ucc_score_t score, char *buf, size_t max); + /* Allocates empty score data structure */ ucc_status_t ucc_coll_score_alloc(ucc_coll_score_t **score); @@ -77,7 +79,7 @@ ucc_status_t ucc_coll_score_add_range(ucc_coll_score_t *score, /* Releases the score data structure and all the score ranges stored there */ -void ucc_coll_score_free(ucc_coll_score_t *score); +void ucc_coll_score_free(ucc_coll_score_t *score); /* Merges 2 scores score1 and score2 into the new score "rst" selecting larger score. Ie.: rst will contain a range from score1 if either @@ -87,9 +89,9 @@ void ucc_coll_score_free(ucc_coll_score_t *score); This fn is used by CL to merge scores from multiple TLs and produce a score map. As a result the produced score map will select TL with higher score.*/ -ucc_status_t ucc_coll_score_merge(ucc_coll_score_t * score1, - ucc_coll_score_t * score2, - ucc_coll_score_t **rst, int free_inputs); +ucc_status_t ucc_coll_score_merge(ucc_coll_score_t * score1, + ucc_coll_score_t * score2, + ucc_coll_score_t **rst, int free_inputs); /* Parses SCORE string (see ucc_base_iface.c for pattern description) @@ -147,7 +149,7 @@ ucc_status_t ucc_coll_score_build_default(ucc_base_team_t *team, ucc_status_t ucc_coll_score_build_map(ucc_coll_score_t *score, ucc_score_map_t **map); -void ucc_coll_score_free_map(ucc_score_map_t *map); +void ucc_coll_score_free_map(ucc_score_map_t *map); /* Initializes task based on args selection and score map. Checks fallbacks if necessary. */ diff --git a/src/coll_score/ucc_coll_score_map.c b/src/coll_score/ucc_coll_score_map.c index 5b67260bd8..037476efb2 100644 --- a/src/coll_score/ucc_coll_score_map.c +++ b/src/coll_score/ucc_coll_score_map.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -74,10 +74,9 @@ void ucc_coll_score_free_map(ucc_score_map_t *map) ucc_free(map); } -static -ucc_status_t ucc_coll_score_map_lookup(ucc_score_map_t *map, - ucc_base_coll_args_t *bargs, - ucc_msg_range_t **range) +static ucc_status_t ucc_coll_score_map_lookup(ucc_score_map_t *map, + ucc_base_coll_args_t *bargs, + ucc_msg_range_t **range) { ucc_memory_type_t mt = ucc_coll_args_mem_type(&bargs->args, map->team_rank); @@ -85,11 +84,12 @@ ucc_status_t ucc_coll_score_map_lookup(ucc_score_map_t *map, size_t msgsize = ucc_coll_args_msgsize(&bargs->args, map->team_rank, map->team_size); - ucc_list_link_t *list; - ucc_msg_range_t *r; + ucc_list_link_t *list; + ucc_msg_range_t *r; if (mt == UCC_MEMORY_TYPE_ASYMMETRIC) { /* TODO */ + ucc_debug("asymmetric memory type is not supported"); return UCC_ERR_NOT_SUPPORTED; } else if (mt == UCC_MEMORY_TYPE_NOT_APPLY) { /* Temporary solution: for Barrier, Fanin, Fanout - use @@ -122,7 +122,9 @@ ucc_status_t ucc_coll_init(ucc_score_map_t *map, ucc_status_t status; status = ucc_coll_score_map_lookup(map, bargs, &r); - if (UCC_OK != status) { + if (ucc_unlikely(UCC_OK != status)) { + ucc_debug("coll_score_map lookup failed %d (%s)", + status, ucc_status_string(status)); return status; } @@ -160,11 +162,12 @@ ucc_status_t ucc_coll_init(ucc_score_map_t *map, void ucc_coll_score_map_print_info(const ucc_score_map_t *map) { - size_t left; - ucc_msg_range_t *range; - int i, j, all_empty; - char range_str[128]; - char coll_str[1024]; + size_t left; + ucc_msg_range_t *range; + int i, j, all_empty; + char score_str[32]; + char range_str[128]; + char coll_str[1024]; for (i = 0; i < UCC_COLL_TYPE_NUM; i++) { all_empty = 1; @@ -191,10 +194,12 @@ void ucc_coll_score_map_print_info(const ucc_score_map_t *map) super.list_elem) { ucc_memunits_range_str(range->start, range->end, range_str, sizeof(range_str)); - STR_APPEND(coll_str, left, 256, "{%s}:%s:%u ", + ucc_score_to_str(range->super.score, score_str, + sizeof(score_str)); + STR_APPEND(coll_str, left, 256, "{%s}:%s:%s ", range_str, range->super.team->context->lib->log_component.name, - range->super.score); + score_str); } STR_APPEND(coll_str, left, 4, "\n"); } diff --git a/src/components/cl/hier/Makefile.am b/src/components/cl/hier/Makefile.am index c99d72e96f..243f5811e8 100644 --- a/src/components/cl/hier/Makefile.am +++ b/src/components/cl/hier/Makefile.am @@ -25,6 +25,11 @@ bcast = \ bcast/bcast.c \ bcast/bcast_2step.c +reduce = \ + reduce/reduce.h \ + reduce/reduce.c \ + reduce/reduce_2step.c + sources = \ cl_hier.h \ cl_hier.c \ @@ -37,7 +42,8 @@ sources = \ $(alltoallv) \ $(alltoall) \ $(barrier) \ - $(bcast) + $(bcast) \ + $(reduce) module_LTLIBRARIES = libucc_cl_hier.la libucc_cl_hier_la_SOURCES = $(sources) diff --git a/src/components/cl/hier/cl_hier.c b/src/components/cl/hier/cl_hier.c index 87f2b2c370..edbb469d78 100644 --- a/src/components/cl/hier/cl_hier.c +++ b/src/components/cl/hier/cl_hier.c @@ -71,6 +71,11 @@ static ucc_config_field_t ucc_cl_hier_lib_config_table[] = { ucc_offsetof(ucc_cl_hier_lib_config_t, bcast_2step_pipeline), UCC_CONFIG_TYPE_PIPELINE_PARAMS}, + {"REDUCE_2STEP_PIPELINE", "n", + "Pipelining settings for RAB reduce algorithm", + ucc_offsetof(ucc_cl_hier_lib_config_t, reduce_2step_pipeline), + UCC_CONFIG_TYPE_PIPELINE_PARAMS}, + {NULL}}; static ucs_config_field_t ucc_cl_hier_context_config_table[] = { diff --git a/src/components/cl/hier/cl_hier.h b/src/components/cl/hier/cl_hier.h index 8f538c1d7b..ef40f33118 100644 --- a/src/components/cl/hier/cl_hier.h +++ b/src/components/cl/hier/cl_hier.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) Meta Platforms, Inc. and affiliates. 2022. * * See file LICENSE for terms. @@ -53,6 +53,7 @@ typedef struct ucc_cl_hier_lib_config { ucc_pipeline_params_t allreduce_split_rail_pipeline; ucc_pipeline_params_t allreduce_rab_pipeline; ucc_pipeline_params_t bcast_2step_pipeline; + ucc_pipeline_params_t reduce_2step_pipeline; } ucc_cl_hier_lib_config_t; typedef struct ucc_cl_hier_context_config { @@ -109,8 +110,12 @@ typedef struct ucc_cl_hier_team { UCC_CLASS_DECLARE(ucc_cl_hier_team_t, ucc_base_context_t *, const ucc_base_team_params_t *); -#define UCC_CL_HIER_SUPPORTED_COLLS \ - (UCC_COLL_TYPE_ALLTOALL | UCC_COLL_TYPE_ALLTOALLV) +#define UCC_CL_HIER_SUPPORTED_COLLS \ + (UCC_COLL_TYPE_ALLTOALL | \ + UCC_COLL_TYPE_ALLTOALLV | \ + UCC_COLL_TYPE_ALLREDUCE | \ + UCC_COLL_TYPE_BARRIER | \ + UCC_COLL_TYPE_BCAST) ucc_status_t ucc_cl_hier_coll_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, diff --git a/src/components/cl/hier/cl_hier_coll.c b/src/components/cl/hier/cl_hier_coll.c index acdb243ddd..b7fd507843 100644 --- a/src/components/cl/hier/cl_hier_coll.c +++ b/src/components/cl/hier/cl_hier_coll.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -13,7 +13,8 @@ const char * ucc_cl_hier_default_alg_select_str[UCC_CL_HIER_N_DEFAULT_ALG_SELECT_STR] = { UCC_CL_HIER_ALLREDUCE_DEFAULT_ALG_SELECT_STR, - UCC_CL_HIER_BCAST_DEFAULT_ALG_SELECT_STR}; + UCC_CL_HIER_BCAST_DEFAULT_ALG_SELECT_STR, + UCC_CL_HIER_REDUCE_DEFAULT_ALG_SELECT_STR}; ucc_status_t ucc_cl_hier_coll_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, @@ -22,14 +23,16 @@ ucc_status_t ucc_cl_hier_coll_init(ucc_base_coll_args_t *coll_args, switch (coll_args->args.coll_type) { case UCC_COLL_TYPE_ALLREDUCE: return ucc_cl_hier_allreduce_rab_init(coll_args, team, task); - case UCC_COLL_TYPE_BARRIER: - return ucc_cl_hier_barrier_init(coll_args, team, task); case UCC_COLL_TYPE_ALLTOALL: return ucc_cl_hier_alltoall_init(coll_args, team, task); case UCC_COLL_TYPE_ALLTOALLV: return ucc_cl_hier_alltoallv_init(coll_args, team, task); + case UCC_COLL_TYPE_BARRIER: + return ucc_cl_hier_barrier_init(coll_args, team, task); case UCC_COLL_TYPE_BCAST: return ucc_cl_hier_bcast_2step_init(coll_args, team, task); + case UCC_COLL_TYPE_REDUCE: + return ucc_cl_hier_reduce_2step_init(coll_args, team, task); default: cl_error(team->context->lib, "coll_type %s is not supported", ucc_coll_type_str(coll_args->args.coll_type)); @@ -41,14 +44,16 @@ ucc_status_t ucc_cl_hier_coll_init(ucc_base_coll_args_t *coll_args, static inline int alg_id_from_str(ucc_coll_type_t coll_type, const char *str) { switch (coll_type) { + case UCC_COLL_TYPE_ALLREDUCE: + return ucc_cl_hier_allreduce_alg_from_str(str); case UCC_COLL_TYPE_ALLTOALLV: return ucc_cl_hier_alltoallv_alg_from_str(str); case UCC_COLL_TYPE_ALLTOALL: return ucc_cl_hier_alltoall_alg_from_str(str); - case UCC_COLL_TYPE_ALLREDUCE: - return ucc_cl_hier_allreduce_alg_from_str(str); case UCC_COLL_TYPE_BCAST: return ucc_cl_hier_bcast_alg_from_str(str); + case UCC_COLL_TYPE_REDUCE: + return ucc_cl_hier_reduce_alg_from_str(str); default: break; } @@ -66,6 +71,19 @@ ucc_status_t ucc_cl_hier_alg_id_to_init(int alg_id, const char *alg_id_str, } switch (coll_type) { + case UCC_COLL_TYPE_ALLREDUCE: + switch (alg_id) { + case UCC_CL_HIER_ALLREDUCE_ALG_RAB: + *init = ucc_cl_hier_allreduce_rab_init; + break; + case UCC_CL_HIER_ALLREDUCE_ALG_SPLIT_RAIL: + *init = ucc_cl_hier_allreduce_split_rail_init; + break; + default: + status = UCC_ERR_INVALID_PARAM; + break; + }; + break; case UCC_COLL_TYPE_ALLTOALLV: switch (alg_id) { case UCC_CL_HIER_ALLTOALLV_ALG_NODE_SPLIT: @@ -86,28 +104,25 @@ ucc_status_t ucc_cl_hier_alg_id_to_init(int alg_id, const char *alg_id_str, break; }; break; - case UCC_COLL_TYPE_ALLREDUCE: + case UCC_COLL_TYPE_BCAST: switch (alg_id) { - case UCC_CL_HIER_ALLREDUCE_ALG_RAB: - *init = ucc_cl_hier_allreduce_rab_init; - break; - case UCC_CL_HIER_ALLREDUCE_ALG_SPLIT_RAIL: - *init = ucc_cl_hier_allreduce_split_rail_init; + case UCC_CL_HIER_BCAST_ALG_2STEP: + *init = ucc_cl_hier_bcast_2step_init; break; default: status = UCC_ERR_INVALID_PARAM; break; }; break; - case UCC_COLL_TYPE_BCAST: - switch (alg_id) { - case UCC_CL_HIER_BCAST_ALG_2STEP: - *init = ucc_cl_hier_bcast_2step_init; + case UCC_COLL_TYPE_REDUCE: + switch(alg_id) { + case UCC_CL_HIER_REDUCE_ALG_2STEP: + *init = ucc_cl_hier_reduce_2step_init; break; default: status = UCC_ERR_INVALID_PARAM; break; - }; + } break; default: status = UCC_ERR_NOT_SUPPORTED; diff --git a/src/components/cl/hier/cl_hier_coll.h b/src/components/cl/hier/cl_hier_coll.h index 3258796675..5a1e294afe 100644 --- a/src/components/cl/hier/cl_hier_coll.h +++ b/src/components/cl/hier/cl_hier_coll.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -14,8 +14,9 @@ #include "alltoall/alltoall.h" #include "barrier/barrier.h" #include "bcast/bcast.h" +#include "reduce/reduce.h" -#define UCC_CL_HIER_N_DEFAULT_ALG_SELECT_STR 2 +#define UCC_CL_HIER_N_DEFAULT_ALG_SELECT_STR 3 extern const char *ucc_cl_hier_default_alg_select_str[UCC_CL_HIER_N_DEFAULT_ALG_SELECT_STR]; diff --git a/src/components/cl/hier/cl_hier_team.c b/src/components/cl/hier/cl_hier_team.c index 31f0e9f707..32ef7e2f93 100644 --- a/src/components/cl/hier/cl_hier_team.c +++ b/src/components/cl/hier/cl_hier_team.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -43,6 +43,11 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context, ucc_config_names_array_t *tls; ucc_subset_t subset; struct ucc_team_team_desc *d; + ucc_tl_context_t *tl_ctx; + ucc_tl_lib_t *tl_lib; + ucc_base_lib_attr_t attr; + + if (!params->team->topo) { cl_debug(cl_context->lib, "can't create hier team without topology data"); @@ -74,6 +79,13 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context, hs->n_tls = 0; tls = &lib->cfg.sbgp_tls[i].array; for (j = 0; j < tls->count; j++) { + if (hs->n_tls == CL_HIER_MAX_SBGP_TLS) { + cl_debug(cl_context->lib, + "skipping tl context %s for %s sbgp: " + "max number of TLs per SBGP is reached", + tls->names[j], ucc_sbgp_str(hs->sbgp_type)); + continue; + } status = ucc_tl_context_get(ctx->super.super.ucc_context, tls->names[j], &hs->tl_ctxs[hs->n_tls]); @@ -81,11 +93,37 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context, cl_debug(cl_context->lib, "tl context %s is not available for sbgp %s", tls->names[j], ucc_sbgp_str(hs->sbgp_type)); - } else { - hs->n_tls++; - n_sbgp_teams++; - ucc_assert(hs->n_tls <= CL_HIER_MAX_SBGP_TLS); + continue; } + attr.mask = UCC_BASE_LIB_ATTR_FIELD_MIN_TEAM_SIZE | + UCC_BASE_LIB_ATTR_FIELD_MAX_TEAM_SIZE; + tl_ctx = hs->tl_ctxs[hs->n_tls]; + tl_lib = ucc_derived_of(tl_ctx->super.lib, ucc_tl_lib_t); + status = tl_lib->iface->lib.get_attr(tl_ctx->super.lib, + &attr); + if (status != UCC_OK) { + cl_debug(cl_context->lib, + "failed to get attributes for tl context %s", + tls->names[j]); + ucc_tl_context_put(tl_ctx); + continue; + } + + if (hs->sbgp->group_size < attr.min_team_size || + hs->sbgp->group_size > attr.max_team_size) { + cl_debug(cl_context->lib, + "tl context %s is not suitable for sbgp %s" + "sbgp: sbgp size %d is not in range [%d; %d]", + tls->names[j], ucc_sbgp_str(hs->sbgp_type), + hs->sbgp->group_size, + attr.min_team_size, attr.max_team_size); + ucc_tl_context_put(tl_ctx); + continue; + } + + hs->n_tls++; + n_sbgp_teams++; + ucc_assert(hs->n_tls <= CL_HIER_MAX_SBGP_TLS); } } } @@ -325,7 +363,7 @@ ucc_status_t ucc_cl_hier_team_get_scores(ucc_base_team_t *cl_team, team_info.init = ucc_cl_hier_coll_init; team_info.num_mem_types = 0; team_info.supported_mem_types = NULL; /* all memory types supported*/ - team_info.supported_colls = UCC_COLL_TYPE_ALL; + team_info.supported_colls = UCC_CL_HIER_SUPPORTED_COLLS; team_info.size = UCC_CL_TEAM_SIZE(team); status = ucc_coll_score_alloc(&score); diff --git a/src/components/cl/hier/reduce/reduce.c b/src/components/cl/hier/reduce/reduce.c new file mode 100644 index 0000000000..6ba9e72892 --- /dev/null +++ b/src/components/cl/hier/reduce/reduce.c @@ -0,0 +1,17 @@ +/** + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#include "reduce.h" +#include "../reduce/reduce.h" + +ucc_base_coll_alg_info_t + ucc_cl_hier_reduce_algs[UCC_CL_HIER_REDUCE_ALG_LAST + 1] = { + [UCC_CL_HIER_REDUCE_ALG_2STEP] = + {.id = UCC_CL_HIER_REDUCE_ALG_2STEP, + .name = "2step", + .desc = "intra-node and inter-node reduces executed in parallel"}, + [UCC_CL_HIER_REDUCE_ALG_LAST] = { + .id = 0, .name = NULL, .desc = NULL}}; diff --git a/src/components/cl/hier/reduce/reduce.h b/src/components/cl/hier/reduce/reduce.h new file mode 100644 index 0000000000..fdea260996 --- /dev/null +++ b/src/components/cl/hier/reduce/reduce.h @@ -0,0 +1,38 @@ +/** + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#ifndef REDUCE_H_ +#define REDUCE_H_ +#include "../cl_hier.h" + +enum +{ + UCC_CL_HIER_REDUCE_ALG_2STEP, + UCC_CL_HIER_REDUCE_ALG_LAST, +}; + +extern ucc_base_coll_alg_info_t + ucc_cl_hier_reduce_algs[UCC_CL_HIER_REDUCE_ALG_LAST + 1]; + +#define UCC_CL_HIER_REDUCE_DEFAULT_ALG_SELECT_STR "reduce:0-4k:@2step" + +ucc_status_t ucc_cl_hier_reduce_2step_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task); + +static inline int ucc_cl_hier_reduce_alg_from_str(const char *str) +{ + int i; + + for (i = 0; i < UCC_CL_HIER_REDUCE_ALG_LAST; i++) { + if (0 == strcasecmp(str, ucc_cl_hier_reduce_algs[i].name)) { + break; + } + } + return i; +} + +#endif diff --git a/src/components/cl/hier/reduce/reduce_2step.c b/src/components/cl/hier/reduce/reduce_2step.c new file mode 100644 index 0000000000..bb10434058 --- /dev/null +++ b/src/components/cl/hier/reduce/reduce_2step.c @@ -0,0 +1,318 @@ +/** + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#include "reduce.h" +#include "core/ucc_team.h" +#include "../cl_hier_coll.h" + +#define MAX_AR_2STEP_TASKS 3 + +static ucc_status_t ucc_cl_hier_reduce_2step_start(ucc_coll_task_t *task) +{ + UCC_CL_HIER_PROFILE_REQUEST_EVENT(task, "cl_hier_reduce_2step_start", 0); + return ucc_schedule_start(task); +} + +static ucc_status_t ucc_cl_hier_reduce_2step_finalize(ucc_coll_task_t *task) +{ + ucc_cl_hier_schedule_t *schedule = ucc_derived_of(task, ucc_cl_hier_schedule_t); + ucc_status_t status; + + UCC_CL_HIER_PROFILE_REQUEST_EVENT(task, "cl_hier_reduce_2step_finalize", 0); + status = ucc_schedule_finalize(task); + if (schedule->scratch) { + ucc_mc_free(schedule->scratch); + } + ucc_cl_hier_put_schedule(&schedule->super.super); + return status; +} + +static inline ucc_rank_t +find_root_net_rank(ucc_host_id_t root_host_id, ucc_cl_hier_team_t *cl_team) +{ + ucc_rank_t net_rank = UCC_RANK_INVALID; + ucc_sbgp_t *sbgp = cl_team->sbgps[UCC_HIER_SBGP_NODE_LEADERS].sbgp; + ucc_team_t *core_team = cl_team->super.super.params.team; + ucc_rank_t i, rank; + + for (i = 0; i < sbgp->group_size; i++) { + rank = ucc_ep_map_eval(sbgp->map, i); + if (ucc_team_rank_host_id(rank, core_team) == root_host_id) { + net_rank = i; + break; + } + } + return net_rank; +} + +static inline ucc_rank_t +find_root_node_rank(ucc_rank_t root, ucc_cl_hier_team_t *cl_team) +{ + ucc_rank_t node_rank = UCC_RANK_INVALID; + ucc_sbgp_t *sbgp = cl_team->sbgps[UCC_HIER_SBGP_NODE].sbgp; + ucc_rank_t i; + + for (i = 0; i < sbgp->group_size; i++) { + if (ucc_ep_map_eval(sbgp->map, i) == root) { + node_rank = i; + break; + } + } + return node_rank; +} + +static ucc_status_t +ucc_cl_hier_reduce_2step_init_schedule(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_schedule_t **sched_p, int n_frags) +{ + ucc_cl_hier_team_t *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t); + ucc_team_t *core_team = team->params.team; + ucc_coll_task_t *tasks[2] = {NULL, NULL}; + ucc_rank_t root = coll_args->args.root; + ucc_rank_t rank = UCC_TL_TEAM_RANK(cl_team); + ucc_base_coll_args_t args = *coll_args; + size_t count = (rank == root) ? + args.args.dst.info.count : + args.args.src.info.count; + ucc_cl_hier_schedule_t *cl_schedule; + ucc_schedule_t *schedule; + ucc_status_t status; + int n_tasks, i, first_task, root_on_local_node; + + root_on_local_node = ucc_team_ranks_on_same_node(root, rank, core_team); + n_tasks = 0; + first_task = 0; + + if (root != rank) { + args.args.dst.info.count = args.args.src.info.count; + args.args.dst.info.mem_type = args.args.src.info.mem_type; + args.args.dst.info.datatype = args.args.src.info.datatype; + args.args.mask &= (~UCC_COLL_ARGS_FLAG_IN_PLACE); + } + + cl_schedule = ucc_cl_hier_get_schedule(cl_team); + if (ucc_unlikely(!cl_schedule)) { + return UCC_ERR_NO_MEMORY; + } + + schedule = &cl_schedule->super.super; + status = ucc_schedule_init(schedule, &args, team); + if (ucc_unlikely(UCC_OK != status)) { + goto out; + } + + args.max_frag_count = ucc_buffer_block_count(count, n_frags, 0); + if (n_frags > 1) { + args.mask |= UCC_BASE_CARGS_MAX_FRAG_COUNT; + } + + ucc_assert(SBGP_ENABLED(cl_team, NODE_LEADERS) || + SBGP_ENABLED(cl_team, NODE)); + if (SBGP_ENABLED(cl_team, NODE)) { + args.args.root = root_on_local_node + ? find_root_node_rank(root, cl_team) : 0; + + if ((root != rank) && SBGP_ENABLED(cl_team, NODE_LEADERS)) { + status = ucc_mc_alloc(&cl_schedule->scratch, + args.max_frag_count * + ucc_dt_size(args.args.src.info.datatype), + args.args.src.info.mem_type); + if (ucc_unlikely(UCC_OK != status)) { + goto out; + } + args.args.dst.info.buffer = cl_schedule->scratch->addr; + if (root_on_local_node) { + first_task = 1; + args.args.src.info.buffer = cl_schedule->scratch->addr; + } + } + status = ucc_coll_init(SCORE_MAP(cl_team, NODE), &args, &tasks[n_tasks]); + if (ucc_unlikely(UCC_OK != status)) { + goto out; + } + n_tasks++; + } + + if (SBGP_ENABLED(cl_team, NODE_LEADERS)) { + if (n_tasks == 1) { + if (root != rank) { + args.args.src.info.buffer = root_on_local_node ? + coll_args->args.src.info.buffer : cl_schedule->scratch->addr; + } else { + args.args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + args.args.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; + } + } + args.args.root = find_root_net_rank(ucc_team_rank_host_id(root, core_team), + cl_team); + status = ucc_coll_init(SCORE_MAP(cl_team, NODE_LEADERS), + &args, &tasks[n_tasks]); + if (ucc_unlikely(UCC_OK != status)) { + goto out; + } + n_tasks++; + } + + ucc_task_subscribe_dep(&schedule->super, tasks[first_task], + UCC_EVENT_SCHEDULE_STARTED); + ucc_schedule_add_task(schedule, tasks[first_task]); + + if (n_tasks > 1) { + ucc_task_subscribe_dep(tasks[first_task], + tasks[(first_task + 1) % 2], UCC_EVENT_COMPLETED); + ucc_schedule_add_task(schedule, tasks[(first_task + 1) % 2]); + } + + schedule->super.post = ucc_cl_hier_reduce_2step_start; + schedule->super.progress = NULL; + schedule->super.finalize = ucc_cl_hier_reduce_2step_finalize; + schedule->super.triggered_post = ucc_triggered_post; + *sched_p = schedule; + return UCC_OK; + +out: + for (i = 0; i < n_tasks; i++) { + tasks[i]->finalize(tasks[i]); + } + ucc_cl_hier_put_schedule(schedule); + return status; +} + +static ucc_status_t +ucc_cl_hier_reduce_2step_frag_init(ucc_base_coll_args_t *coll_args, + ucc_schedule_pipelined_t *sp, + ucc_base_team_t *team, + ucc_schedule_t **frag_p) +{ + int n_frags = sp->super.n_tasks; + ucc_status_t status; + + status = ucc_cl_hier_reduce_2step_init_schedule(coll_args, team, frag_p, + n_frags); + return status; +} + +static ucc_status_t +ucc_cl_hier_reduce_2step_frag_setup(ucc_schedule_pipelined_t *schedule_p, + ucc_schedule_t *frag, int frag_num) +{ + ucc_cl_hier_team_t *cl_team = ucc_derived_of(schedule_p->super.super.team, + ucc_cl_hier_team_t); + ucc_coll_args_t *args = &schedule_p->super.super.bargs.args; + size_t dt_size = ucc_dt_size(args->src.info.datatype); + int n_frags = schedule_p->super.n_tasks; + ucc_rank_t root = args->root; + ucc_rank_t rank = UCC_TL_TEAM_RANK(cl_team); + size_t count = (rank == root) ? args->dst.info.count : + args->src.info.count; + size_t frag_count, frag_offset; + ucc_coll_task_t *task; + int i; + ucc_cl_hier_schedule_t *cl_schedule; + void *scratch; + + cl_schedule = ucc_derived_of(frag, ucc_cl_hier_schedule_t); + scratch = cl_schedule->scratch ? cl_schedule->scratch->addr : NULL; + frag_count = ucc_buffer_block_count(count, n_frags, frag_num); + frag_offset = ucc_buffer_block_offset(count, n_frags, frag_num); + + for (i = 0; i < frag->n_tasks; i++) { + task = frag->tasks[i]; + task->bargs.args.src.info.count = frag_count; + task->bargs.args.dst.info.count = frag_count; + if (task->bargs.args.src.info.buffer != scratch) { + task->bargs.args.src.info.buffer = + PTR_OFFSET(args->src.info.buffer, frag_offset * dt_size); + } + if (task->bargs.args.dst.info.buffer != scratch) { + task->bargs.args.dst.info.buffer = + PTR_OFFSET(args->dst.info.buffer, frag_offset * dt_size); + } + } + return UCC_OK; +} + +static ucc_status_t +ucc_cl_hier_reduce_2step_pipelined_start(ucc_coll_task_t *task) +{ + ucc_schedule_pipelined_t *schedule = + ucc_derived_of(task, ucc_schedule_pipelined_t); + + cl_debug(task->team->context->lib, + "posting reduce_2step, count %zd, dt %s" + " pdepth %d, frags_total %d", + task->bargs.args.src.info.count, + ucc_datatype_str(task->bargs.args.src.info.datatype), + schedule->n_frags, schedule->super.n_tasks); + + return ucc_schedule_pipelined_post(task); +} + +static ucc_status_t +ucc_cl_hier_reduce_2step_pipelined_finalize(ucc_coll_task_t *task) +{ + ucc_cl_hier_schedule_t *schedule = ucc_derived_of(task, + ucc_cl_hier_schedule_t); + ucc_status_t status; + + status = ucc_schedule_pipelined_finalize(&schedule->super.super.super); + ucc_cl_hier_put_schedule(&schedule->super.super); + return status; +} + +UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_reduce_2step_init, + (coll_args, team, task), + ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, + ucc_coll_task_t **task) +{ + ucc_cl_hier_team_t *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t); + ucc_cl_hier_lib_config_t *cfg = &UCC_CL_HIER_TEAM_LIB(cl_team)->cfg; + ucc_cl_hier_schedule_t *schedule; + int n_frags, pipeline_depth; + ucc_status_t status; + + if (UCC_IS_PERSISTENT(coll_args->args) || + (coll_args->args.op == UCC_OP_AVG)) { + return UCC_ERR_NOT_SUPPORTED; + } + + ucc_pipeline_nfrags_pdepth(&cfg->reduce_2step_pipeline, + coll_args->args.src.info.count * + ucc_dt_size(coll_args->args.src.info.datatype), + &n_frags, &pipeline_depth); + + if (n_frags == 1) { + return ucc_cl_hier_reduce_2step_init_schedule( + coll_args, team, (ucc_schedule_t **)task, n_frags); + } + + schedule = ucc_cl_hier_get_schedule(cl_team); + if (ucc_unlikely(!schedule)) { + return UCC_ERR_NO_MEMORY; + } + + status = ucc_schedule_pipelined_init( + coll_args, team, ucc_cl_hier_reduce_2step_frag_init, + ucc_cl_hier_reduce_2step_frag_setup, pipeline_depth, n_frags, + cfg->reduce_2step_pipeline.order, &schedule->super); + + if (ucc_unlikely(status != UCC_OK)) { + cl_error(team->context->lib, + "failed to init pipelined 2step ar schedule"); + goto err_pipe_init; + } + + schedule->super.super.super.post = ucc_cl_hier_reduce_2step_pipelined_start; + schedule->super.super.super.finalize = ucc_cl_hier_reduce_2step_pipelined_finalize; + schedule->super.super.super.triggered_post = ucc_triggered_post; + *task = &schedule->super.super.super; + return UCC_OK; + +err_pipe_init: + ucc_cl_hier_put_schedule(&schedule->super.super); + return status; +} diff --git a/src/components/ec/cpu/ec_cpu.c b/src/components/ec/cpu/ec_cpu.c index 8d08d2365c..b94052002e 100644 --- a/src/components/ec/cpu/ec_cpu.c +++ b/src/components/ec/cpu/ec_cpu.c @@ -113,8 +113,12 @@ ucc_status_t ucc_cpu_executor_task_post(ucc_ee_executor_t *executor, eee_task->eee = executor; switch (task_args->task_type) { case UCC_EE_EXECUTOR_TASK_REDUCE: - status = ucc_ec_cpu_reduce((ucc_eee_task_reduce_t *)&task_args->reduce, - task_args->flags); + status = ucc_ec_cpu_reduce((ucc_eee_task_reduce_t *)&task_args->reduce, task_args->reduce.dst, + (task_args->flags & + UCC_EEE_TASK_FLAG_REDUCE_SRCS_EXT) ? + task_args->reduce.srcs_ext : + task_args->reduce.srcs, + task_args->flags); if (ucc_unlikely(UCC_OK != status)) { goto free_task; } @@ -147,7 +151,7 @@ ucc_status_t ucc_cpu_executor_task_post(ucc_ee_executor_t *executor, tr.dst = trs->dst; tr.alpha = trs->alpha; - status = ucc_ec_cpu_reduce(&tr, flags); + status = ucc_ec_cpu_reduce(&tr, tr.dst, srcs, flags); if (ucc_unlikely(UCC_OK != status)) { goto free_task; } diff --git a/src/components/ec/cpu/ec_cpu.h b/src/components/ec/cpu/ec_cpu.h index 0c6b63c92c..de96f90869 100644 --- a/src/components/ec/cpu/ec_cpu.h +++ b/src/components/ec/cpu/ec_cpu.h @@ -25,5 +25,5 @@ typedef struct ucc_ec_cpu { extern ucc_ec_cpu_t ucc_ec_cpu; -ucc_status_t ucc_ec_cpu_reduce(ucc_eee_task_reduce_t *task, uint16_t flags); +ucc_status_t ucc_ec_cpu_reduce(ucc_eee_task_reduce_t *task, void * restrict dst, void * const * restrict srcs, uint16_t flags); #endif diff --git a/src/components/ec/cpu/ec_cpu_reduce.c b/src/components/ec/cpu/ec_cpu_reduce.c index a805d67294..2936f84c83 100644 --- a/src/components/ec/cpu/ec_cpu_reduce.c +++ b/src/components/ec/cpu/ec_cpu_reduce.c @@ -12,48 +12,49 @@ do { \ size_t _i, _j; \ type _tmp; \ + size_t __count = _count; \ switch (_n_srcs) { \ case 2: \ - for (_i = 0; _i < _count; _i++) { \ + for (_i = 0; _i < __count; _i++) { \ d[_i] = OP##_2(s[0][_i], s[1][_i]); \ } \ break; \ case 3: \ - for (_i = 0; _i < _count; _i++) { \ + for (_i = 0; _i < __count; _i++) { \ d[_i] = OP##_3(s[0][_i], s[1][_i], s[2][_i]); \ } \ break; \ case 4: \ - for (_i = 0; _i < _count; _i++) { \ + for (_i = 0; _i < __count; _i++) { \ d[_i] = OP##_4(s[0][_i], s[1][_i], s[2][_i], s[3][_i]); \ } \ break; \ case 5: \ - for (_i = 0; _i < _count; _i++) { \ + for (_i = 0; _i < __count; _i++) { \ d[_i] = \ OP##_5(s[0][_i], s[1][_i], s[2][_i], s[3][_i], s[4][_i]); \ } \ break; \ case 6: \ - for (_i = 0; _i < _count; _i++) { \ + for (_i = 0; _i < __count; _i++) { \ d[_i] = OP##_6(s[0][_i], s[1][_i], s[2][_i], s[3][_i], \ s[4][_i], s[5][_i]); \ } \ break; \ case 7: \ - for (_i = 0; _i < _count; _i++) { \ + for (_i = 0; _i < __count; _i++) { \ d[_i] = OP##_7(s[0][_i], s[1][_i], s[2][_i], s[3][_i], \ s[4][_i], s[5][_i], s[6][_i]); \ } \ break; \ case 8: \ - for (_i = 0; _i < _count; _i++) { \ + for (_i = 0; _i < __count; _i++) { \ d[_i] = OP##_8(s[0][_i], s[1][_i], s[2][_i], s[3][_i], \ s[4][_i], s[5][_i], s[6][_i], s[7][_i]); \ } \ break; \ default: \ - for (_i = 0; _i < _count; _i++) { \ + for (_i = 0; _i < __count; _i++) { \ _tmp = OP##_8(s[0][_i], s[1][_i], s[2][_i], s[3][_i], \ s[4][_i], s[5][_i], s[6][_i], s[7][_i]); \ for (_j = 8; _j < _n_srcs; _j++) { \ @@ -223,47 +224,45 @@ } \ } while (0) -ucc_status_t ucc_ec_cpu_reduce(ucc_eee_task_reduce_t *task, uint16_t flags) +ucc_status_t ucc_ec_cpu_reduce(ucc_eee_task_reduce_t *task, void * restrict dst, + void * const * restrict srcs, uint16_t flags) { - void **srcs = (flags & UCC_EEE_TASK_FLAG_REDUCE_SRCS_EXT) ? task->srcs_ext - : task->srcs; - switch (task->dt) { case UCC_DT_INT8: - DO_DT_REDUCE_INT(int8_t, srcs, task->dst, task->op, task->count, + DO_DT_REDUCE_INT(int8_t, srcs, dst, task->op, task->count, task->n_srcs); break; case UCC_DT_INT16: - DO_DT_REDUCE_INT(int16_t, srcs, task->dst, task->op, task->count, + DO_DT_REDUCE_INT(int16_t, srcs, dst, task->op, task->count, task->n_srcs); break; case UCC_DT_INT32: - DO_DT_REDUCE_INT(int32_t, srcs, task->dst, task->op, task->count, + DO_DT_REDUCE_INT(int32_t, srcs, dst, task->op, task->count, task->n_srcs); break; case UCC_DT_INT64: - DO_DT_REDUCE_INT(int64_t, srcs, task->dst, task->op, task->count, + DO_DT_REDUCE_INT(int64_t, srcs, dst, task->op, task->count, task->n_srcs); break; case UCC_DT_UINT8: - DO_DT_REDUCE_INT(uint8_t, srcs, task->dst, task->op, task->count, + DO_DT_REDUCE_INT(uint8_t, srcs, dst, task->op, task->count, task->n_srcs); break; case UCC_DT_UINT16: - DO_DT_REDUCE_INT(uint16_t, srcs, task->dst, task->op, task->count, + DO_DT_REDUCE_INT(uint16_t, srcs, dst, task->op, task->count, task->n_srcs); break; case UCC_DT_UINT32: - DO_DT_REDUCE_INT(uint32_t, srcs, task->dst, task->op, task->count, + DO_DT_REDUCE_INT(uint32_t, srcs, dst, task->op, task->count, task->n_srcs); break; case UCC_DT_UINT64: - DO_DT_REDUCE_INT(uint64_t, srcs, task->dst, task->op, task->count, + DO_DT_REDUCE_INT(uint64_t, srcs, dst, task->op, task->count, task->n_srcs); break; case UCC_DT_FLOAT32: #if SIZEOF_FLOAT == 4 - DO_DT_REDUCE_FLOAT(float, srcs, task->dst, task->op, task->count, + DO_DT_REDUCE_FLOAT(float, srcs, dst, task->op, task->count, task->n_srcs); break; #else @@ -271,7 +270,7 @@ ucc_status_t ucc_ec_cpu_reduce(ucc_eee_task_reduce_t *task, uint16_t flags) #endif case UCC_DT_FLOAT64: #if SIZEOF_DOUBLE == 8 - DO_DT_REDUCE_FLOAT(double, srcs, task->dst, task->op, task->count, + DO_DT_REDUCE_FLOAT(double, srcs, dst, task->op, task->count, task->n_srcs); break; #else @@ -279,19 +278,19 @@ ucc_status_t ucc_ec_cpu_reduce(ucc_eee_task_reduce_t *task, uint16_t flags) #endif case UCC_DT_FLOAT128: #if SIZEOF_LONG_DOUBLE == 16 - DO_DT_REDUCE_FLOAT(long double, srcs, task->dst, task->op, task->count, + DO_DT_REDUCE_FLOAT(long double, srcs, dst, task->op, task->count, task->n_srcs); break; #else return UCC_ERR_NOT_SUPPORTED; #endif case UCC_DT_BFLOAT16: - DO_DT_REDUCE_BFLOAT16(srcs, task->dst, task->op, task->count, + DO_DT_REDUCE_BFLOAT16(srcs, dst, task->op, task->count, task->n_srcs); break; case UCC_DT_FLOAT32_COMPLEX: #if SIZEOF_FLOAT__COMPLEX == 8 - DO_DT_REDUCE_FLOAT_COMPLEX(float complex, srcs, task->dst, task->op, + DO_DT_REDUCE_FLOAT_COMPLEX(float complex, srcs, dst, task->op, task->count, task->n_srcs); break; #else @@ -299,7 +298,7 @@ ucc_status_t ucc_ec_cpu_reduce(ucc_eee_task_reduce_t *task, uint16_t flags) #endif case UCC_DT_FLOAT64_COMPLEX: #if SIZEOF_DOUBLE__COMPLEX == 16 - DO_DT_REDUCE_FLOAT_COMPLEX(double complex, srcs, task->dst, task->op, + DO_DT_REDUCE_FLOAT_COMPLEX(double complex, srcs, dst, task->op, task->count, task->n_srcs); break; #else @@ -307,7 +306,7 @@ ucc_status_t ucc_ec_cpu_reduce(ucc_eee_task_reduce_t *task, uint16_t flags) #endif case UCC_DT_FLOAT128_COMPLEX: #if SIZEOF_LONG_DOUBLE__COMPLEX == 32 - DO_DT_REDUCE_FLOAT_COMPLEX(long double complex, srcs, task->dst, + DO_DT_REDUCE_FLOAT_COMPLEX(long double complex, srcs, dst, task->op, task->count, task->n_srcs); break; #else diff --git a/src/components/ec/cuda/ec_cuda.c b/src/components/ec/cuda/ec_cuda.c index dd721e1f50..5dbb87fd4b 100644 --- a/src/components/ec/cuda/ec_cuda.c +++ b/src/components/ec/cuda/ec_cuda.c @@ -282,7 +282,9 @@ ucc_status_t ucc_ec_cuda_get_resources(ucc_ec_cuda_resources_t **resources) #else status = CUDADRV_FUNC(cuCtxGetId(cu_ctx, &cu_ctx_id)); if (ucc_unlikely(status != UCC_OK)) { - ec_error(&ucc_ec_cuda.super, "failed to get currect CUDA context ID"); + /* worakround for pytorch, progress thread doesn't have cuda context for GPU 0*/ + cu_ctx_id = 0x12345; + ec_debug(&ucc_ec_cuda.super, "failed to get currect CUDA context ID"); } #endif diff --git a/src/components/mc/cuda/mc_cuda.c b/src/components/mc/cuda/mc_cuda.c index aa2638b9da..72b73b4e67 100644 --- a/src/components/mc/cuda/mc_cuda.c +++ b/src/components/mc/cuda/mc_cuda.c @@ -368,7 +368,9 @@ ucc_status_t ucc_mc_cuda_get_resources(ucc_mc_cuda_resources_t **resources) #else status = CUDADRV_FUNC(cuCtxGetId(cu_ctx, &cu_ctx_id)); if (ucc_unlikely(status != UCC_OK)) { - mc_error(&ucc_mc_cuda.super, "failed to get currect CUDA context ID"); + /* worakround for pytorch, progress thread doesn't have cuda context for GPU 0*/ + cu_ctx_id = 0x12345; + mc_debug(&ucc_mc_cuda.super, "failed to get currect CUDA context ID"); } #endif diff --git a/src/components/tl/cuda/tl_cuda_team.c b/src/components/tl/cuda/tl_cuda_team.c index faa2ad89cf..64123a8cea 100644 --- a/src/components/tl/cuda/tl_cuda_team.c +++ b/src/components/tl/cuda/tl_cuda_team.c @@ -219,6 +219,11 @@ ucc_status_t ucc_tl_cuda_team_create_test(ucc_base_team_t *tl_team) } team->oob.req_free(team->oob_req); team->oob_req = (void*)0x1; + + for (i = 0; i < UCC_TL_TEAM_SIZE(team); i++) { + team->scratch.rem[i] = NULL; + } + status = ucc_tl_cuda_team_topo_create(&team->super, &team->topo); if (status != UCC_OK) { goto exit_err; diff --git a/src/components/tl/mlx5/Makefile.am b/src/components/tl/mlx5/Makefile.am index 2ac9dc91c7..01b94cfa1d 100644 --- a/src/components/tl/mlx5/Makefile.am +++ b/src/components/tl/mlx5/Makefile.am @@ -22,6 +22,7 @@ mcast = \ mcast/p2p/ucc_tl_mlx5_mcast_p2p.h \ mcast/p2p/ucc_tl_mlx5_mcast_p2p.c \ mcast/tl_mlx5_mcast_progress.h \ + mcast/tl_mlx5_mcast_progress.c \ mcast/tl_mlx5_mcast_helper.h \ mcast/tl_mlx5_mcast_helper.c \ mcast/tl_mlx5_mcast_team.c @@ -51,7 +52,7 @@ libucc_tl_mlx5_la_SOURCES = $(sources) libucc_tl_mlx5_la_CPPFLAGS = $(AM_CPPFLAGS) $(BASE_CPPFLAGS) libucc_tl_mlx5_la_CFLAGS = $(BASE_CFLAGS) libucc_tl_mlx5_la_LDFLAGS = -version-info $(SOVERSION) --as-needed -libucc_tl_mlx5_la_LIBADD = $(UCC_TOP_BUILDDIR)/src/libucc.la $(IBVERBS_LIBADD) $(MLX5DV_LIBADD) +libucc_tl_mlx5_la_LIBADD = $(UCC_TOP_BUILDDIR)/src/libucc.la $(IBVERBS_LIBADD) $(MLX5DV_LIBADD) $(RDMACM_LIBADD) include $(top_srcdir)/config/module.am diff --git a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c index c8dca25fc3..d5c5d9dfb4 100644 --- a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c +++ b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ diff --git a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h index 6e19e59dde..e82f7546a7 100644 --- a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h +++ b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h index cab8046281..1d08f99edf 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -16,6 +16,7 @@ #include "components/tl/ucc_tl.h" #include "components/tl/ucc_tl_log.h" #include "utils/ucc_rcache.h" +#include "core/ucc_service_coll.h" #define POLL_PACKED 16 #define REL_DONE ((void*)-1) @@ -63,14 +64,14 @@ typedef struct ucc_tl_mlx5_mcast_p2p_completion_obj { typedef int (*ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t)(void *wait_arg); -typedef int (*ucc_tl_mlx5_mcast_p2p_send_nb_fn_t)(void* src, size_t size, - ucc_rank_t rank, void *context, - ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj); +typedef ucc_status_t (*ucc_tl_mlx5_mcast_p2p_send_nb_fn_t)(void* src, size_t size, + ucc_rank_t rank, void *context, + ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj); -typedef int (*ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t)(void* src, size_t size, - ucc_rank_t rank, void *context, - ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj); +typedef ucc_status_t (*ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t)(void* src, size_t size, + ucc_rank_t rank, void *context, + ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj); typedef struct ucc_tl_mlx5_mcast_p2p_interface { ucc_tl_mlx5_mcast_p2p_send_nb_fn_t send_nb; @@ -119,6 +120,7 @@ typedef struct ucc_tl_mlx5_mcast_rcache_region { } ucc_tl_mlx5_mcast_rcache_region_t; typedef struct ucc_tl_mlx5_mcast_ctx_params { + int mcast_enabled; char *ib_dev_name; int print_nack_stats; int timeout; @@ -142,11 +144,19 @@ typedef struct ucc_tl_mlx5_mcast_coll_context { ucc_base_lib_t *lib; } ucc_tl_mlx5_mcast_coll_context_t; +typedef struct ucc_tl_mlx5_mcast_join_info_t { + ucc_status_t status; + uint16_t dlid; + union ibv_gid dgid; +} ucc_tl_mlx5_mcast_join_info_t; + typedef struct ucc_tl_mlx5_mcast_context { ucc_thread_mode_t tm; ucc_tl_mlx5_mcast_coll_context_t mcast_context; ucc_tl_mlx5_mcast_context_config_t cfg; ucc_mpool_t req_mp; + int mcast_enabled; + int mcast_ready; ucc_tl_mlx5_mcast_oob_ctx_t oob_ctx; } ucc_tl_mlx5_mcast_context_t; @@ -218,13 +228,18 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm { void *p2p_ctx; ucc_base_lib_t *lib; struct sockaddr_in6 mcast_addr; - int parents[MAX_COMM_POW2]; - int children[MAX_COMM_POW2]; + ucc_rank_t parents[MAX_COMM_POW2]; + ucc_rank_t children[MAX_COMM_POW2]; int nack_requests; int nacks_counter; int n_prep_reliable; int n_mcast_reliable; int wsize; + ucc_tl_mlx5_mcast_join_info_t *group_setup_info; + ucc_service_coll_req_t *group_setup_info_req; + ucc_status_t (*bcast_post) (void*, void*, size_t, ucc_rank_t, ucc_service_coll_req_t**); + ucc_status_t (*bcast_test) (ucc_service_coll_req_t*); + struct rdma_cm_event *event; struct pp_packet *r_window[1]; // do not add any new variable after here } ucc_tl_mlx5_mcast_coll_comm_t; @@ -352,6 +367,8 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context, const ucc_base_team_params_t *params, ucc_tl_mlx5_mcast_coll_comm_init_spec_t *mcast_conf); +ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team); + ucc_status_t ucc_tl_mlx5_mcast_coll_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, ucc_coll_task_t **task_h); @@ -359,4 +376,6 @@ ucc_status_t ucc_tl_mlx5_mcast_coll_init(ucc_base_coll_args_t *coll_args, ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *mcast_ctx, ucc_tl_mlx5_mcast_ctx_params_t *mcast_ctx_conf); + +ucc_status_t ucc_tl_mlx5_mcast_clean_ctx(ucc_tl_mlx5_mcast_coll_context_t *ctx); #endif diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c index 1cd2f56512..4669c88640 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c @@ -1,22 +1,248 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #include "tl_mlx5_coll.h" #include "tl_mlx5_mcast_helper.h" +#include "tl_mlx5_mcast_rcache.h" -ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* req /* NOLINT */) +static inline ucc_status_t ucc_tl_mlx5_mcast_r_window_recycle(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req) { - return UCC_ERR_NOT_SUPPORTED; + ucc_status_t status = UCC_OK; + int wsize = comm->wsize; + int num_free_win = wsize - (comm->psn - comm->last_acked); + int req_completed = (req->to_send == 0 && req->to_recv == 0); + struct pp_packet *pp = NULL; + + ucc_assert(comm->recv_drop_packet_in_progress == false); + ucc_assert(req->to_send >= 0); + + /* When do we need to perform reliability protocol: + 1. Always in the end of the window + 2. For the zcopy case: in the end of collective, because we can't signal completion + before made sure that children received the data - user can modify buffer */ + + ucc_assert(num_free_win >= 0); + + if (!num_free_win || (req->proto == MCAST_PROTO_ZCOPY && req_completed)) { + status = ucc_tl_mlx5_mcast_reliable(comm); + if (UCC_OK != status) { + return status; + } + + comm->n_mcast_reliable++; + + for (;comm->last_acked < comm->psn; comm->last_acked++) { + pp = comm->r_window[comm->last_acked & (wsize-1)]; + ucc_assert(pp != &comm->dummy_packet); + comm->r_window[comm->last_acked & (wsize-1)] = &comm->dummy_packet; + + pp->context = 0; + ucc_list_add_tail(&comm->bpool, &pp->super); + } + + if (!req_completed) { + status = ucc_tl_mlx5_mcast_prepare_reliable(comm, req, req->root); + if (ucc_unlikely(UCC_OK != status)) { + return status; + } + } + } + + return UCC_OK; +} + +static inline ucc_status_t ucc_tl_mlx5_mcast_do_bcast(ucc_tl_mlx5_mcast_coll_req_t *req) +{ + ucc_status_t status = UCC_OK; + ucc_tl_mlx5_mcast_coll_comm_t *comm = req->comm; + int zcopy = req->proto != MCAST_PROTO_EAGER; + int wsize = comm->wsize; + int num_free_win; + int num_sent; + int to_send; + int to_recv; + int to_recv_left; + int pending_q_size; + + if (ucc_unlikely(comm->recv_drop_packet_in_progress)) { + /* wait till parent resend the dropped packet */ + return UCC_INPROGRESS; + } + + if (comm->reliable_in_progress) { + /* wait till all the children send their ACK for current window */ + status = ucc_tl_mlx5_mcast_r_window_recycle(comm, req); + if (UCC_OK != status) { + return status; + } + } + + if (req->to_send || req->to_recv) { + num_free_win = wsize - (comm->psn - comm->last_acked); + + /* Send data if i'm root and there is a space in the window */ + if (num_free_win && req->am_root) { + num_sent = req->num_packets - req->to_send; + ucc_assert(req->to_send > 0); + ucc_assert(req->first_send_psn + num_sent < comm->last_acked + wsize); + if (req->first_send_psn + num_sent < comm->last_acked + wsize && + req->to_send) { + /* How many to send: either all that are left (if they fit into window) or + up to the window limit */ + to_send = ucc_min(req->to_send, + comm->last_acked + wsize - (req->first_send_psn + num_sent)); + ucc_tl_mlx5_mcast_send(comm, req, to_send, zcopy); + + num_free_win = wsize - (comm->psn - comm->last_acked); + } + } + + status = ucc_tl_mlx5_mcast_prepare_reliable(comm, req, req->root); + if (ucc_unlikely(UCC_OK != status)) { + return status; + } + + if (num_free_win && req->to_recv) { + /* How many to recv: either all that are left or up to the window limit. */ + pending_q_size = 0; + to_recv = ucc_min(num_free_win, req->to_recv); + to_recv_left = ucc_tl_mlx5_mcast_recv(comm, req, to_recv, &pending_q_size); + + if (to_recv == to_recv_left) { + /* We didn't receive anything: increase the stalled counter and get ready for + drop event */ + if (comm->stalled++ >= DROP_THRESHOLD) { + + tl_trace(comm->lib, "Did not receive the packet with psn in" + " current window range, so get ready for drop" + " event. pending_q_size %d current comm psn %d" + " last_acked psn %d stall threshold %d ", + pending_q_size, comm->psn, comm->last_acked, + DROP_THRESHOLD); + + status = ucc_tl_mlx5_mcast_bcast_check_drop(comm, req); + if (UCC_INPROGRESS == status) { + return status; + } + } + } else if (to_recv_left < 0) { + /* a failure happend during cq polling */ + return UCC_ERR_NO_MESSAGE; + } else { + comm->stalled = 0; + comm->timer = 0; + } + } + + /* This function will check if we have to do a round of reliability protocol */ + status = ucc_tl_mlx5_mcast_r_window_recycle(comm, req); + if (UCC_OK != status) { + return status; + } + } + + if (req->to_send || req->to_recv || (zcopy && comm->psn != comm->last_acked)) { + return UCC_INPROGRESS; + } else { + return status; + } +} + + +ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* req) +{ + ucc_status_t status = UCC_OK; + + ucc_assert(req->comm->psn >= req->start_psn); + + status = ucc_tl_mlx5_mcast_do_bcast(req); + if (UCC_INPROGRESS != status) { + ucc_assert(req->comm->ctx != NULL); + ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->rreg); + req->rreg = NULL; + } + + return status; +} + +static inline ucc_status_t ucc_tl_mlx5_mcast_prepare_bcast(void* buf, size_t size, ucc_rank_t root, + ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req) +{ + ucc_status_t status; + ucc_tl_mlx5_mcast_reg_t *reg; + + req->comm = comm; + req->ptr = buf; + req->length = size; + req->root = root; + req->am_root = (root == comm->rank); + req->mr = comm->pp_mr; + req->rreg = NULL; + req->proto = (req->length < comm->max_eager) ? MCAST_PROTO_EAGER : MCAST_PROTO_ZCOPY; + + status = ucc_tl_mlx5_mcast_prepare_reliable(comm, req, req->root); + if (ucc_unlikely(UCC_OK != status)) { + return status; + } + + if (req->am_root) { + if (req->proto != MCAST_PROTO_EAGER) { + status = ucc_tl_mlx5_mcast_mem_register(comm->ctx, req->ptr, req->length, ®); + if (UCC_OK != status) { + return status; + } + req->rreg = reg; + req->mr = reg->mr; + } + } + + req->offset = 0; + req->start_psn = comm->last_psn; + req->num_packets = ucc_max(ucc_div_round_up(req->length, comm->max_per_packet), 1); + req->last_pkt_len = req->length - (req->num_packets - 1)*comm->max_per_packet; + + ucc_assert(req->last_pkt_len > 0 && req->last_pkt_len <= comm->max_per_packet); + + comm->last_psn += req->num_packets; + req->first_send_psn = req->start_psn; + req->to_send = req->am_root ? req->num_packets : 0; + req->to_recv = req->am_root ? 0 : req->num_packets; + + return UCC_OK; } -ucc_status_t mcast_coll_do_bcast(void* buf, size_t size, ucc_rank_t root, void *mr, /* NOLINT */ - ucc_tl_mlx5_mcast_coll_comm_t *comm, /* NOLINT */ - ucc_tl_mlx5_mcast_coll_req_t **task_req_handle /* NOLINT */) +ucc_status_t ucc_tl_mlx5_mcast_coll_do_bcast(void* buf, size_t size, ucc_rank_t root, + ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t **task_req_handle) { - return UCC_ERR_NOT_SUPPORTED; + ucc_status_t status; + ucc_tl_mlx5_mcast_coll_req_t *req; + + tl_trace(comm->lib, "MCAST bcast start, buf %p, size %ld, root %d, comm %d, " + "comm_size %d, am_i_root %d comm->psn = %d \n", + buf, size, root, comm->comm_id, comm->commsize, comm->rank == + root, comm->psn ); + + req = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_coll_req_t), "mcast_req"); + if (!req) { + return UCC_ERR_NO_MEMORY; + } + + status = ucc_tl_mlx5_mcast_prepare_bcast(buf, size, root, comm, req); + if (UCC_OK != status) { + ucc_free(req); + return status; + } + + status = UCC_INPROGRESS; + *task_req_handle = req; + + return status; } ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task) @@ -35,8 +261,8 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task) task->bcast_mcast.req_handle = NULL; - status = mcast_coll_do_bcast(buf, data_size, root, NULL, comm, - &task->bcast_mcast.req_handle); + status = ucc_tl_mlx5_mcast_coll_do_bcast(buf, data_size, root, comm, + &task->bcast_mcast.req_handle); if (status < 0) { tl_error(UCC_TASK_LIB(task), "mcast_coll_do_bcast failed:%d", status); coll_task->status = status; @@ -50,23 +276,16 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task) void ucc_tl_mlx5_mcast_collective_progress(ucc_coll_task_t *coll_task) { - ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t); - ucc_status_t status = UCC_OK; - ucc_tl_mlx5_mcast_coll_req_t *req = task->bcast_mcast.req_handle; + ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t); + ucc_tl_mlx5_mcast_coll_req_t *req = task->bcast_mcast.req_handle; if (req != NULL) { - status = ucc_tl_mlx5_mcast_test(req); - if (UCC_OK == status) { - coll_task->status = UCC_OK; - ucc_free(req); - task->bcast_mcast.req_handle = NULL; - } + coll_task->status = ucc_tl_mlx5_mcast_test(req); } } ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task) { - task->super.post = ucc_tl_mlx5_mcast_bcast_start; task->super.progress = ucc_tl_mlx5_mcast_collective_progress; diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h index 47ddc301aa..74385b1573 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c index ad32c459b0..192000ee86 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -51,12 +51,20 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont int ib_valid; const char *dst; + mlx5_ctx = ucc_container_of(context, ucc_tl_mlx5_context_t, mcast); + lib = mlx5_ctx->super.super.lib; + + context->mcast_enabled = mcast_ctx_conf->mcast_enabled; + + if (!mcast_ctx_conf->mcast_enabled) { + tl_debug(lib, "Mcast is disabled by the user"); + return UCC_ERR_NO_RESOURCE; + } + ctx = &(context->mcast_context); memset(ctx, 0, sizeof(ucc_tl_mlx5_mcast_coll_context_t)); memcpy(&ctx->params, mcast_ctx_conf, sizeof(ucc_tl_mlx5_mcast_ctx_params_t)); - mlx5_ctx = ucc_container_of(context, ucc_tl_mlx5_context_t, mcast); - lib = mlx5_ctx->super.super.lib; ctx->lib = lib; /* TODO unify all the contexts under TL mlx5 */ @@ -239,13 +247,55 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont error: if (ctx->pd) { ibv_dealloc_pd(ctx->pd); + ctx->pd = NULL; } if (ctx->id) { rdma_destroy_id(ctx->id); + ctx->id = NULL; } if (ctx->channel) { rdma_destroy_event_channel(ctx->channel); + ctx->channel = NULL; } return status; } + +ucc_status_t ucc_tl_mlx5_mcast_clean_ctx(ucc_tl_mlx5_mcast_coll_context_t *ctx) +{ + tl_debug(ctx->lib, "cleaning mcast ctx: %p", ctx); + + if (ctx == NULL) return UCC_OK; + + if (ctx->rcache) { + ucc_rcache_destroy(ctx->rcache); + ctx->rcache = NULL; + } + + if (ctx->pd) { + if (ibv_dealloc_pd(ctx->pd)) { + tl_error(ctx->lib, "ibv_dealloc_pd failed errno %d", errno); + return UCC_ERR_NO_RESOURCE; + } + ctx->pd = NULL; + } + + if (ctx->id && rdma_destroy_id(ctx->id)) { + tl_error(ctx->lib, "rdma_destroy_id failed errno %d", errno); + return UCC_ERR_NO_RESOURCE; + } + + ctx->id = NULL; + + if (ctx->channel) { + rdma_destroy_event_channel(ctx->channel); + ctx->channel = NULL; + } + + if (ctx->devname && !strcmp(ctx->params.ib_dev_name, "")) { + ucc_free(ctx->devname); + ctx->devname = NULL; + } + + return UCC_OK; +} diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c index 8c52a63c73..81d142b3a1 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -417,15 +417,21 @@ ucc_status_t ucc_tl_mlx5_fini_mcast_group(ucc_tl_mlx5_mcast_coll_context_t *ctx, ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm) { - int ret; - ucc_status_t status; + ucc_tl_mlx5_mcast_context_t *mcast_ctx = ucc_container_of(comm->ctx, ucc_tl_mlx5_mcast_context_t, mcast_context); + ucc_tl_mlx5_context_t *mlx5_ctx = ucc_container_of(mcast_ctx, ucc_tl_mlx5_context_t, mcast); + ucc_context_h context = mlx5_ctx->super.super.ucc_context; + int ret; + ucc_status_t status; tl_debug(comm->lib, "cleaning mcast comm: %p, id %d, mlid %x", comm, comm->comm_id, comm->mcast_lid); - if (UCC_OK != (status = ucc_tl_mlx5_mcast_reliable(comm))) { - // TODO handle (UCC_INPROGRESS == ret) - tl_error(comm->lib, "couldn't clean mcast team: relibality progress status %d", + while (UCC_INPROGRESS == (status = ucc_tl_mlx5_mcast_reliable(comm))) { + ucc_context_progress(context); + } + + if (UCC_OK != status) { + tl_error(comm->lib, "failed to clean mcast team: relibality progress status %d", status); return status; } @@ -529,33 +535,3 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm) return UCC_OK; } -ucc_status_t ucc_tl_mlx5_clean_mcast_ctx(ucc_tl_mlx5_mcast_coll_context_t *ctx) -{ - tl_debug(ctx->lib, "cleaning mcast ctx: %p", ctx); - - if (ctx->rcache) { - ucc_rcache_destroy(ctx->rcache); - } - - if (ctx->pd) { - if (ibv_dealloc_pd(ctx->pd)) { - tl_error(ctx->lib, "ibv_dealloc_pd failed errno %d", errno); - return UCC_ERR_NO_RESOURCE; - } - } - - if (rdma_destroy_id(ctx->id)) { - tl_error(ctx->lib, "rdma_destroy_id failed errno %d", errno); - return UCC_ERR_NO_RESOURCE; - } - - rdma_destroy_event_channel(ctx->channel); - - if (!strcmp(ctx->params.ib_dev_name, "")) { - ucc_free(ctx->devname); - } - - ucc_free(ctx); - - return UCC_OK; -} diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h index 05037e495f..427039316d 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -319,7 +319,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable(ucc_tl_mlx5_mcast_coll_com } } - status = ucc_tl_mlx5_mcast_check_nack_requests_all(comm); + status = ucc_tl_mlx5_mcast_check_nack_requests(comm, UINT32_MAX); if (UCC_OK != status) { return status; } @@ -365,4 +365,12 @@ ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm); +ucc_status_t ucc_tl_mlx5_mcast_join_mcast_post(ucc_tl_mlx5_mcast_coll_context_t *ctx, + struct sockaddr_in6 *net_addr, + int is_root); + +ucc_status_t ucc_tl_mlx5_mcast_join_mcast_test(ucc_tl_mlx5_mcast_coll_context_t *ctx, + struct rdma_cm_event **event, + int is_root); + #endif /* TL_MLX5_MCAST_HELPER_H_ */ diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c new file mode 100644 index 0000000000..a201944ecf --- /dev/null +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c @@ -0,0 +1,378 @@ +/** + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#include "tl_mlx5_mcast_progress.h" + +static ucc_status_t ucc_tl_mlx5_mcast_recv_completion(ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj); + +static ucc_status_t ucc_tl_mlx5_mcast_send_completion(ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj); + +static ucc_status_t ucc_tl_mlx5_mcast_dummy_completion(ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj) // NOLINT +{ + return UCC_OK; +} + +static ucc_tl_mlx5_mcast_p2p_completion_obj_t dummy_completion_obj = { + .compl_cb = ucc_tl_mlx5_mcast_dummy_completion, +}; + +static inline ucc_status_t ucc_tl_mlx5_mcast_resend_packet_reliable(ucc_tl_mlx5_mcast_coll_comm_t *comm, + int p2p_pkt_id) +{ + uint32_t psn = comm->p2p_pkt[p2p_pkt_id].psn; + struct pp_packet *pp = comm->r_window[psn % comm->wsize]; + ucc_status_t status; + + ucc_assert(pp->psn == psn); + + tl_trace(comm->lib, "[comm %d, rank %d] Send data NACK: to %d, psn %d, context %ld\n", + comm->comm_id, comm->rank, + comm->p2p_pkt[p2p_pkt_id].from, psn, pp->context); + + status = comm->params.p2p_iface.send_nb((void*) (pp->context ? pp->context : pp->buf), + pp->length, comm->p2p_pkt[p2p_pkt_id].from, + comm->p2p_ctx, &dummy_completion_obj); + if (status < 0) { + return status; + } + + status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[p2p_pkt_id], + sizeof(struct packet), comm->p2p_pkt[p2p_pkt_id].from, + comm->p2p_ctx, GET_COMPL_OBJ(comm, + ucc_tl_mlx5_mcast_recv_completion, p2p_pkt_id, NULL)); + if (status < 0) { + return status; + } + + return UCC_OK; +} + +ucc_status_t ucc_tl_mlx5_mcast_check_nack_requests(ucc_tl_mlx5_mcast_coll_comm_t *comm, uint32_t psn) +{ + ucc_status_t status = UCC_OK; + int i; + struct pp_packet *pp; + + if (!comm->nack_requests) { + return UCC_OK; + } + + if (psn != UINT32_MAX) { + for (i=0; ichild_n; i++) { + if (psn == comm->p2p_pkt[i].psn && + comm->p2p_pkt[i].type == MCAST_P2P_NEED_NACK_SEND) { + status = ucc_tl_mlx5_mcast_resend_packet_reliable(comm, i); + if (status != UCC_OK) { + break; + } + comm->p2p_pkt[i].type = MCAST_P2P_ACK; + comm->nack_requests--; + } + } + } else { + for (i=0; ichild_n; i++){ + if (comm->p2p_pkt[i].type == MCAST_P2P_NEED_NACK_SEND) { + psn = comm->p2p_pkt[i].psn; + pp = comm->r_window[psn % comm->wsize]; + if (psn == pp->psn) { + status = ucc_tl_mlx5_mcast_resend_packet_reliable(comm, i); + if (status < 0) { + break; + } + comm->p2p_pkt[i].type = MCAST_P2P_ACK; + comm->nack_requests--; + } + } + } + } + + return status; +} + +static inline int ucc_tl_mlx5_mcast_find_nack_psn(ucc_tl_mlx5_mcast_coll_comm_t* comm, + ucc_tl_mlx5_mcast_coll_req_t *req) +{ + int psn = ucc_max(comm->last_acked, req->start_psn); + int max_search_psn = ucc_min(req->start_psn + req->num_packets, + comm->last_acked + comm->wsize + 1); + + for (; psn < max_search_psn; psn++) { + if (!PSN_RECEIVED(psn, comm)) { + break; + } + } + + ucc_assert(psn < max_search_psn); + + return psn; +} + +static inline ucc_rank_t ucc_tl_mlx5_mcast_get_nack_parent(ucc_tl_mlx5_mcast_coll_req_t *req) +{ + return req->parent; +} + +/* When parent resend the lost packet to a child, this function is called at child side */ +static ucc_status_t ucc_tl_mlx5_mcast_recv_data_completion(ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj) +{ + ucc_status_t status = UCC_OK; + ucc_tl_mlx5_mcast_coll_comm_t *comm = (ucc_tl_mlx5_mcast_coll_comm_t *)obj->data[0]; + struct pp_packet *pp = (struct pp_packet *)obj->data[1]; + ucc_tl_mlx5_mcast_coll_req_t *req = (ucc_tl_mlx5_mcast_coll_req_t *)obj->data[2]; + void *dest; + + tl_trace(comm->lib, "[comm %d, rank %d] Recved data psn %d", comm->comm_id, comm->rank, pp->psn); + + dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm); + memcpy(dest, (void*) pp->buf, pp->length); + req->to_recv--; + comm->r_window[pp->psn % comm->wsize] = pp; + + status = ucc_tl_mlx5_mcast_check_nack_requests(comm, pp->psn); + if (status < 0) { + return status; + } + + comm->psn++; + comm->recv_drop_packet_in_progress = false; + + return status; +} + +static inline ucc_status_t ucc_tl_mlx5_mcast_reliable_send_NACK(ucc_tl_mlx5_mcast_coll_comm_t* comm, + ucc_tl_mlx5_mcast_coll_req_t *req) +{ + struct pp_packet *pp; + ucc_rank_t parent; + ucc_status_t status; + + struct packet p = { + .type = MCAST_P2P_NACK, + .psn = ucc_tl_mlx5_mcast_find_nack_psn(comm, req), + .from = comm->rank, + .comm_id = comm->comm_id, + }; + + parent = ucc_tl_mlx5_mcast_get_nack_parent(req); + + comm->nacks_counter++; + + status = comm->params.p2p_iface.send_nb(&p, sizeof(struct packet), parent, + comm->p2p_ctx, &dummy_completion_obj); + if (status < 0) { + return status; + } + + tl_trace(comm->lib, "[comm %d, rank %d] Sent NAK : parent %d, psn %d", + comm->comm_id, comm->rank, parent, p.psn); + + // Prepare to obtain the data. + pp = ucc_tl_mlx5_mcast_buf_get_free(comm); + pp->psn = p.psn; + pp->length = PSN_TO_RECV_LEN(pp->psn, req, comm); + + comm->recv_drop_packet_in_progress = true; + + status = comm->params.p2p_iface.recv_nb((void*) pp->buf, + pp->length, parent, + comm->p2p_ctx, GET_COMPL_OBJ(comm, + ucc_tl_mlx5_mcast_recv_data_completion, pp, req)); + if (status < 0) { + return status; + } + + return UCC_INPROGRESS; +} + +ucc_status_t ucc_tl_mlx5_mcast_reliable_send(ucc_tl_mlx5_mcast_coll_comm_t *comm) +{ + ucc_rank_t i; + ucc_rank_t parent; + ucc_status_t status; + + tl_trace(comm->lib, "comm %p, psn %d, last_acked %d, n_parent %d", + comm, comm->psn, comm->last_acked, comm->parent_n); + + ucc_assert(!comm->reliable_in_progress); + + for (i=0; iparent_n; i++) { + parent = comm->parents[i]; + comm->p2p_spkt[i].type = MCAST_P2P_ACK; + comm->p2p_spkt[i].psn = comm->last_acked + comm->wsize; + comm->p2p_spkt[i].comm_id = comm->comm_id; + + tl_trace(comm->lib, "rank %d, Posting SEND to parent %d, n_parent %d, psn %d", + comm->rank, parent, comm->parent_n, comm->psn); + + status = comm->params.p2p_iface.send_nb(&comm->p2p_spkt[i], + sizeof(struct packet), parent, + comm->p2p_ctx, GET_COMPL_OBJ(comm, + ucc_tl_mlx5_mcast_send_completion, i, NULL)); + if (status < 0) { + return status; + } + } + + return UCC_OK; +} + +static ucc_status_t ucc_tl_mlx5_mcast_recv_completion(ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj) +{ + ucc_tl_mlx5_mcast_coll_comm_t *comm = (ucc_tl_mlx5_mcast_coll_comm_t*)obj->data[0]; + int pkt_id = (int)obj->data[1]; + uint32_t psn; + struct pp_packet *pp; + ucc_status_t status; + + ucc_assert(comm->comm_id == comm->p2p_pkt[pkt_id].comm_id); + + if (comm->p2p_pkt[pkt_id].type != MCAST_P2P_ACK) { + ucc_assert(comm->p2p_pkt[pkt_id].type == MCAST_P2P_NACK); + psn = comm->p2p_pkt[pkt_id].psn; + pp = comm->r_window[psn % comm->wsize]; + + tl_trace(comm->lib, "[comm %d, rank %d] Got NACK: from %d, psn %d, avail %d", + comm->comm_id, comm->rank, + comm->p2p_pkt[pkt_id].from, psn, pp->psn == psn); + + if (pp->psn == psn) { + status = ucc_tl_mlx5_mcast_resend_packet_reliable(comm, pkt_id); + if (status < 0) { + return status; + } + } else { + comm->p2p_pkt[pkt_id].type = MCAST_P2P_NEED_NACK_SEND; + comm->nack_requests++; + } + + } else { + comm->racks_n++; + } + + ucc_mpool_put(obj); /* return the completion object back to the mem pool compl_objects_mp */ + + return UCC_OK; +} + +static ucc_status_t ucc_tl_mlx5_mcast_send_completion(ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj) +{ + ucc_tl_mlx5_mcast_coll_comm_t *comm = (ucc_tl_mlx5_mcast_coll_comm_t*)obj->data[0]; + + comm->sacks_n++; + ucc_mpool_put(obj); + return UCC_OK; +} + +static inline int add_uniq(ucc_rank_t *arr, uint32_t *len, ucc_rank_t value) +{ + int i; + + for (i=0; i<(*len); i++) { + if (arr[i] == value) { + return 0; + } + } + + arr[*len] = value; + (*len)++; + return 1; +} + +ucc_status_t ucc_tl_mlx5_mcast_prepare_reliable(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req, + ucc_rank_t root) +{ + ucc_rank_t mask = 1; + ucc_rank_t vrank = TO_VIRTUAL(comm->rank, comm->commsize, root); + ucc_rank_t child; + ucc_status_t status; + + ucc_assert(comm->commsize <= pow(2, MAX_COMM_POW2)); + + while (mask < comm->commsize) { + if (vrank & mask) { + req->parent = TO_ORIGINAL((vrank ^ mask), comm->commsize, root); + add_uniq(comm->parents, &comm->parent_n, req->parent); + break; + } else { + child = vrank ^ mask; + if (child < comm->commsize) { + child = TO_ORIGINAL(child, comm->commsize, root); + if (add_uniq(comm->children, &comm->child_n, child)) { + tl_trace(comm->lib, "rank %d, Posting RECV from child %d, n_child %d, psn %d", + comm->rank, child, comm->child_n, comm->psn); + + status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[comm->child_n - 1], + sizeof(struct packet), child, + comm->p2p_ctx, GET_COMPL_OBJ(comm, + ucc_tl_mlx5_mcast_recv_completion, comm->child_n - 1, req)); + if (status < 0) { + return status; + } + } + } + } + + mask <<= 1; + } + + return UCC_OK; +} + +static inline uint64_t ucc_tl_mlx5_mcast_get_timer(void) +{ + double t_second = ucc_get_time(); + return (uint64_t) (t_second * 1000000); +} + +ucc_status_t ucc_tl_mlx5_mcast_bcast_check_drop(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req) +{ + ucc_status_t status = UCC_OK; + + if (comm->timer == 0) { + comm->timer = ucc_tl_mlx5_mcast_get_timer(); + } else { + if (ucc_tl_mlx5_mcast_get_timer() - comm->timer >= comm->ctx->params.timeout) { + tl_trace(comm->lib, "[REL] time out %d", comm->psn); + status = ucc_tl_mlx5_mcast_reliable_send_NACK(comm, req); + comm->timer = 0; + } + } + + return status; +} + +ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req, + struct pp_packet* pp) +{ + ucc_status_t status = UCC_OK; + void *dest; + ucc_assert(pp->psn >= req->start_psn && + pp->psn < req->start_psn + req->num_packets); + + ucc_assert(pp->length == PSN_TO_RECV_LEN(pp->psn, req, comm)); + ucc_assert(pp->context == 0); + + if (pp->length > 0 ) { + dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm); + memcpy(dest, (void*) pp->buf, pp->length); + } + + comm->r_window[pp->psn & (comm->wsize-1)] = pp; + status = ucc_tl_mlx5_mcast_check_nack_requests(comm, pp->psn); + if (status < 0) { + return status; + } + + req->to_recv--; + comm->psn++; + ucc_assert(comm->recv_drop_packet_in_progress == false); + + return status; +} + diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.h index da30a4b1c0..1bceb89976 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -42,23 +42,23 @@ _req; \ }) -int ucc_tl_mlx5_mcast_prepare_reliable(ucc_tl_mlx5_mcast_coll_comm_t *comm, - ucc_tl_mlx5_mcast_coll_req_t *req, - ucc_rank_t root); +ucc_status_t ucc_tl_mlx5_mcast_prepare_reliable(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req, + ucc_rank_t root); ucc_status_t ucc_tl_mlx5_mcast_bcast_check_drop(ucc_tl_mlx5_mcast_coll_comm_t *comm, ucc_tl_mlx5_mcast_coll_req_t *req); ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *comm, - ucc_tl_mlx5_mcast_coll_req_t *req, - struct pp_packet* pp); + ucc_tl_mlx5_mcast_coll_req_t *req, + struct pp_packet* pp); ucc_status_t ucc_tl_mlx5_mcast_check_nack_requests(ucc_tl_mlx5_mcast_coll_comm_t *comm, - uint32_t psn); + uint32_t psn); ucc_status_t ucc_tl_mlx5_mcast_reliable_send(ucc_tl_mlx5_mcast_coll_comm_t* comm); -ucc_status_t ucc_tl_mlx5_mcast_check_nack_requests_all(ucc_tl_mlx5_mcast_coll_comm_t* comm); +ucc_status_t ucc_tl_mlx5_mcast_check_nack_requests(ucc_tl_mlx5_mcast_coll_comm_t* comm, uint32_t psn); #endif /* ifndef TL_MLX5_MCAST_PROGRESS_H_ */ diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.c index 75c62ac81f..47f73e485b 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -109,9 +109,8 @@ ucc_tl_mlx5_mcast_mem_register(ucc_tl_mlx5_mcast_coll_context_t *ctx, return UCC_OK; } -ucc_status_t -ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx, - ucc_tl_mlx5_mcast_reg_t *reg) +void ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx, + ucc_tl_mlx5_mcast_reg_t *reg) { ucc_tl_mlx5_mcast_rcache_region_t *region; ucc_rcache_t *rcache; @@ -119,18 +118,16 @@ ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx, rcache = ctx->rcache; if (reg == NULL) { - return UCC_OK; + return; } ucc_assert(rcache != NULL); tl_trace(ctx->lib, "memory deregister mr %p", reg->mr); region = ucc_container_of(reg, ucc_tl_mlx5_mcast_rcache_region_t, reg); ucc_rcache_region_put(rcache, ®ion->super); - - return UCC_OK; } -static ucc_rcache_ops_t ucc_rcache_ops = { +static ucc_rcache_ops_t ucc_tl_mlx5_rcache_ops = { .mem_reg = ucc_tl_mlx5_mcast_rcache_mem_reg_cb, .mem_dereg = ucc_tl_mlx5_mcast_rcache_mem_dereg_cb, .dump_region = ucc_tl_mlx5_mcast_rcache_dump_region_cb @@ -140,15 +137,12 @@ ucc_status_t ucc_tl_mlx5_mcast_setup_rcache(ucc_tl_mlx5_mcast_coll_context_t *ct { ucc_rcache_params_t rcache_params; - rcache_params.ucm_event_priority = 1000; - rcache_params.max_regions = ULONG_MAX; - rcache_params.max_size = SIZE_MAX; + ucc_rcache_set_default_params(&rcache_params); rcache_params.region_struct_size = sizeof(ucc_tl_mlx5_mcast_rcache_region_t); + rcache_params.context = ctx; + rcache_params.ops = &ucc_tl_mlx5_rcache_ops; rcache_params.ucm_events = UCM_EVENT_VM_UNMAPPED | UCM_EVENT_MEM_TYPE_FREE; - rcache_params.context = ctx; - rcache_params.ops = &ucc_rcache_ops; - rcache_params.flags = 0; - return ucc_rcache_create(&rcache_params, "MCAST", &ctx->rcache); + return ucc_rcache_create(&rcache_params, "MLX5_MCAST", &ctx->rcache); } diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.h index e1836704ad..da90f562a1 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -13,5 +13,5 @@ ucc_status_t ucc_tl_mlx5_mcast_mem_register(ucc_tl_mlx5_mcast_coll_context_t *ctx, void *addr, size_t length, ucc_tl_mlx5_mcast_reg_t **reg); -ucc_status_t ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx, - ucc_tl_mlx5_mcast_reg_t *reg); +void ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx, + ucc_tl_mlx5_mcast_reg_t *reg); diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c index f56bc3c1a1..1821b4375c 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -9,14 +9,184 @@ #include "tl_mlx5_mcast_coll.h" #include "coll_score/ucc_coll_score.h" #include "tl_mlx5_mcast_helper.h" +#include "p2p/ucc_tl_mlx5_mcast_p2p.h" +#include "mcast/tl_mlx5_mcast_helper.h" + +static ucc_status_t ucc_tl_mlx5_mcast_service_bcast_post(void *arg, void *buf, size_t size, ucc_rank_t root, + ucc_service_coll_req_t **bcast_req) +{ + ucc_tl_mlx5_mcast_oob_p2p_context_t *ctx = (ucc_tl_mlx5_mcast_oob_p2p_context_t *)arg; + ucc_status_t status = UCC_OK; + ucc_team_t *team = ctx->base_team; + ucc_subset_t subset = ctx->subset; + ucc_service_coll_req_t *req = NULL; + + status = ucc_service_bcast(team, buf, size, root, subset, &req); + if (ucc_unlikely(UCC_OK != status)) { + tl_error(ctx->base_ctx->lib, "tl service mcast bcast failed"); + return status; + } + + *bcast_req = req; + + return status; +} -ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, /* NOLINT */ - ucc_tl_mlx5_mcast_team_t **mcast_team, /* NOLINT */ - ucc_tl_mlx5_mcast_context_t *ctx, /* NOLINT */ - const ucc_base_team_params_t *params, /* NOLINT */ - ucc_tl_mlx5_mcast_coll_comm_init_spec_t *mcast_conf /* NOLINT */) +static ucc_status_t ucc_tl_mlx5_mcast_service_bcast_test(ucc_service_coll_req_t *req) { + ucc_status_t status = UCC_OK; + + status = ucc_service_coll_test(req); + + if (UCC_INPROGRESS != status) { + ucc_service_coll_finalize(req); + } + + return status; +} + +ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, + ucc_tl_mlx5_mcast_team_t **mcast_team, + ucc_tl_mlx5_mcast_context_t *ctx, + const ucc_base_team_params_t *team_params, + ucc_tl_mlx5_mcast_coll_comm_init_spec_t *mcast_conf) +{ + ucc_status_t status; + ucc_subset_t set; + ucc_tl_mlx5_mcast_coll_comm_init_spec_t comm_spec = *mcast_conf; + ucc_tl_mlx5_mcast_coll_context_t *mcast_context = &(ctx->mcast_context); + ucc_tl_mlx5_mcast_coll_comm_init_spec_t *conf_params = &comm_spec; + ucc_context_t *context = base_context->ucc_context; + ucc_tl_mlx5_mcast_team_t *new_mcast_team; + ucc_tl_mlx5_mcast_oob_p2p_context_t *oob_p2p_ctx; + ucc_tl_mlx5_mcast_coll_comm_t *comm; + int i; + + if (!ctx->mcast_enabled || NULL == mcast_context) { + tl_debug(base_context->lib, + "mcast context not available, base_context = %p", + base_context ); + return UCC_ERR_NO_RESOURCE; + } + + new_mcast_team = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_team_t), "new_mcast_team"); + + if (!new_mcast_team) { + return UCC_ERR_NO_MEMORY; + } + + new_mcast_team->mcast_context = ctx; + + /* init p2p interface */ + conf_params->p2p_iface.send_nb = ucc_tl_mlx5_mcast_p2p_send_nb; + conf_params->p2p_iface.recv_nb = ucc_tl_mlx5_mcast_p2p_recv_nb; + + oob_p2p_ctx = ucc_malloc(sizeof(ucc_tl_mlx5_mcast_oob_p2p_context_t), + "oob_p2p_ctx"); + if (!oob_p2p_ctx) { + ucc_free(new_mcast_team); + return UCC_ERR_NO_MEMORY; + } + + oob_p2p_ctx->base_ctx = context; + oob_p2p_ctx->base_team = team_params->team; + oob_p2p_ctx->my_team_rank = team_params->rank; + set.myrank = team_params->rank; + set.map = team_params->map; + oob_p2p_ctx->subset = set; + conf_params->oob = oob_p2p_ctx; + conf_params->sx_sge = 1; + conf_params->rx_sge = 2; + conf_params->scq_moderation = 64; + + comm = (ucc_tl_mlx5_mcast_coll_comm_t*) + ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_coll_comm_t) + + sizeof(struct pp_packet*)*(conf_params->wsize-1), + "ucc_tl_mlx5_mcast_coll_comm_t"); + if (!comm) { + ucc_free(oob_p2p_ctx); + ucc_free(new_mcast_team); + return UCC_ERR_NO_MEMORY; + } + + ucc_list_head_init(&comm->bpool); + ucc_list_head_init(&comm->pending_q); + + comm->bcast_post = ucc_tl_mlx5_mcast_service_bcast_post; + comm->bcast_test = ucc_tl_mlx5_mcast_service_bcast_test; + + memcpy(&comm->params, conf_params, sizeof(*conf_params)); + + comm->wsize = conf_params->wsize; + comm->max_eager = conf_params->max_eager; + comm->comm_id = team_params->id; + comm->ctx = mcast_context; + comm->grh_buf = (char *)ucc_malloc(GRH_LENGTH * sizeof(char), "grh_buf"); + if (!comm->grh_buf) { + status = UCC_ERR_NO_MEMORY; + goto cleanup; + } + + memset(comm->grh_buf, 0, GRH_LENGTH); + + comm->grh_mr = ibv_reg_mr(mcast_context->pd, comm->grh_buf, GRH_LENGTH, + IBV_ACCESS_REMOTE_WRITE | + IBV_ACCESS_LOCAL_WRITE); + if (!comm->grh_mr) { + tl_error(mcast_context->lib, "could not register memory for GRH, errno %d", errno); + status = UCC_ERR_NO_RESOURCE; + goto cleanup; + } + + comm->rcq = ibv_create_cq(mcast_context->ctx, comm->params.rx_depth, NULL, NULL, 0); + if (!comm->rcq) { + ibv_dereg_mr(comm->grh_mr); + tl_error(mcast_context->lib, "could not create recv cq, rx_depth %d, errno %d", + comm->params.rx_depth, errno); + status = UCC_ERR_NO_RESOURCE; + goto cleanup; + } + + comm->scq = ibv_create_cq(mcast_context->ctx, comm->params.sx_depth, NULL, NULL, 0); + if (!comm->scq) { + ibv_dereg_mr(comm->grh_mr); + ibv_destroy_cq(comm->rcq); + tl_error(mcast_context->lib, "could not create send cq, sx_depth %d, errno %d", + comm->params.sx_depth, errno); + status = UCC_ERR_NO_RESOURCE; + goto cleanup; + } + + comm->rank = team_params->rank; + comm->commsize = team_params->size; + comm->max_per_packet = mcast_context->mtu - GRH_LENGTH; + comm->last_acked = comm->last_psn = 0; + comm->racks_n = comm->sacks_n = 0; + comm->child_n = comm->parent_n = 0; + comm->p2p_ctx = conf_params->oob; + + memcpy(&comm->p2p, &conf_params->p2p_iface, + sizeof(ucc_tl_mlx5_mcast_p2p_interface_t)); + + comm->dummy_packet.psn = UINT32_MAX; + + for (i=0; i< comm->wsize; i++) { + comm->r_window[i] = &comm->dummy_packet; + } + + comm->lib = base_context->lib; + new_mcast_team->mcast_comm = comm; + *mcast_team = new_mcast_team; + + tl_debug(base_context->lib, "posted tl mcast team : %p", new_mcast_team); + return UCC_OK; + +cleanup: + ucc_free(comm); + ucc_free(new_mcast_team); + ucc_free(oob_p2p_ctx); + return status; } ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_comm_t *comm) @@ -128,3 +298,295 @@ ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_ ucc_tl_mlx5_clean_mcast_comm(comm); return status; } + +ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) +{ + ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); + ucc_tl_mlx5_mcast_coll_comm_t *comm = tl_team->mcast->mcast_comm; + ucc_status_t status = UCC_OK; + struct sockaddr_in6 net_addr = {0,}; + ucc_tl_mlx5_mcast_join_info_t *data = NULL; + + if (comm->rank == 0) { + switch(tl_team->mcast_state) { + case TL_MLX5_TEAM_STATE_MCAST_INIT: + { + /* now it is time for rank 0 to call rdma_join_multicast() */ + net_addr.sin6_family = AF_INET6; + net_addr.sin6_flowinfo = comm->comm_id; + status = ucc_tl_mlx5_mcast_join_mcast_post(comm->ctx, &net_addr, 1); + if (status < 0) { + tl_error(comm->lib, "rank 0 is unable to join mcast group error %d", status); + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_FAILED; + return UCC_INPROGRESS; + } + + comm->mcast_addr = net_addr; + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST; + + return UCC_INPROGRESS; + } + + case TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST: + { + /* rank 0 has already called rdma_join_multicast() + * it is time to wait for the rdma event to confirm the join */ + status = ucc_tl_mlx5_mcast_join_mcast_test(comm->ctx, &comm->event, 1); + if (UCC_OK != status) { + if (status < 0) { + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_FAILED; + } + if (comm->event) { + if (rdma_ack_cm_event(comm->event) < 0) { + tl_error(comm->lib, "rdma_ack_cm_event failed"); + return UCC_ERR_NO_RESOURCE; + } + comm->event = NULL; + } + return UCC_INPROGRESS; + } + + ucc_assert(comm->event != NULL); + + /* at this point, rank 0 has joined mcast group */ + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_READY; + + return UCC_INPROGRESS; + } + + case TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_READY: + case TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_FAILED: + { + + data = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_join_info_t), + "ucc_tl_mlx5_mcast_join_info_t"); + if (!data) { + tl_error(comm->lib, "unable to allocate memory for group setup info"); + return UCC_ERR_NO_MEMORY; + } + + comm->group_setup_info = data; + + if (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_READY) { + /* rank 0 bcast the lid/gid to other processes */ + data->status = UCC_OK; + data->dgid = comm->event->param.ud.ah_attr.grh.dgid; + data->dlid = comm->event->param.ud.ah_attr.dlid; + comm->mcast_lid = data->dlid; + comm->mgid = data->dgid; + } else { + /* rank 0 bcast the failed status to other processes so others do not hang */ + data->status = UCC_ERR_NO_RESOURCE; + } + + status = comm->bcast_post(comm->p2p_ctx, data, sizeof(ucc_tl_mlx5_mcast_join_info_t), + 0, &comm->group_setup_info_req); + if (UCC_OK != status) { + tl_error(comm->lib, "unable to post bcast for group setup info"); + ucc_free(comm->group_setup_info); + if (comm->event) { + if (rdma_ack_cm_event(comm->event) < 0) { + tl_error(comm->lib, "rdma_ack_cm_event failed"); + return UCC_ERR_NO_RESOURCE; + } + comm->event = NULL; + } + return status; + } + + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_BCAST_POST; + + return UCC_INPROGRESS; + } + + case TL_MLX5_TEAM_STATE_MCAST_GRP_BCAST_POST: + { + /* rank 0 polls bcast request and wait for its completion */ + status = comm->bcast_test(comm->group_setup_info_req); + if (UCC_OK != status) { + /* bcast is not completed yet */ + if (status < 0) { + if (rdma_ack_cm_event(comm->event) < 0) { + tl_error(comm->lib, "rdma_ack_cm_event failed"); + } + ucc_free(comm->group_setup_info); + } + return status; + } + + if (comm->group_setup_info->status != UCC_OK) { + /* rank 0 was not able to join a mcast group so all + * the ranks should return */ + if (rdma_ack_cm_event(comm->event) < 0) { + tl_error(comm->lib, "rdma_ack_cm_event failed"); + } + ucc_free(comm->group_setup_info); + return UCC_ERR_NO_RESOURCE; + } + + ucc_free(comm->group_setup_info); + if (comm->event) { + if (rdma_ack_cm_event(comm->event) < 0) { + tl_error(comm->lib, "rdma_ack_cm_event failed"); + return UCC_ERR_NO_RESOURCE; + } + comm->event = NULL; + } + + /* setup of the rest of the mcast resources */ + status = ucc_tl_mlx5_mcast_coll_setup_comm_resources(comm); + if (UCC_OK != status) { + return status; + } + + tl_debug(comm->lib, "initialized tl mcast team: %p", tl_team); + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_READY; + + return UCC_INPROGRESS; + } + + case TL_MLX5_TEAM_STATE_MCAST_READY: + case TL_MLX5_TEAM_STATE_MCAST_NOT_AVAILABLE: + { + return UCC_OK; + } + + default: + { + tl_error(comm->lib, "unknown state during mcast team: %p create", tl_team); + return UCC_ERR_NO_RESOURCE; + } + } + } else { + /* none rank 0 team create states */ + switch(tl_team->mcast_state) { + case TL_MLX5_TEAM_STATE_MCAST_INIT: + { + /* none 0 ranks bcast post to wait for rank 0 for lid/gid + * of the mcast group */ + data = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_join_info_t), + "ucc_tl_mlx5_mcast_join_info_t"); + if (!data) { + tl_error(comm->lib, "unable to allocate memory for group setup info"); + return UCC_ERR_NO_MEMORY; + } + + status = comm->bcast_post(comm->p2p_ctx, data, sizeof(ucc_tl_mlx5_mcast_join_info_t), + 0, &comm->group_setup_info_req); + if (UCC_OK != status) { + tl_error(comm->lib, "unable to post bcast for group setup info"); + ucc_free(data); + return status; + } + + comm->group_setup_info = data; + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_BCAST_POST; + + return UCC_INPROGRESS; + } + + case TL_MLX5_TEAM_STATE_MCAST_GRP_BCAST_POST: + { + /* none rank 0 processes poll bcast request and wait for its completion */ + status = comm->bcast_test(comm->group_setup_info_req); + if (UCC_OK != status) { + /* bcast is not completed yet */ + if (status < 0) { + ucc_free(comm->group_setup_info); + } + return status; + } + + data = comm->group_setup_info; + status = data->status; + if (UCC_OK != status) { + /* rank 0 was not able to join a mcast group so all + * the ranks should return */ + ucc_free(data); + return status; + } + + /* now it is time for none rank 0 to call rdma_join_multicast() */ + memcpy(&net_addr.sin6_addr, &(data->dgid), sizeof(struct in6_addr)); + net_addr.sin6_family = AF_INET6; + + status = ucc_tl_mlx5_mcast_join_mcast_post(comm->ctx, &net_addr, 0); + if (status < 0) { + tl_error(comm->lib, "none-root rank is unable to join mcast group error %d", status); + ucc_free(data); + return status; + } + + comm->mcast_addr = net_addr; + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST; + + return UCC_INPROGRESS; + } + + case TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST: + { + /* none-root rank has already called rdma_join_multicast() + * it is time to wait for the rdma event to confirm the join */ + status = ucc_tl_mlx5_mcast_join_mcast_test(comm->ctx, &comm->event, 0); + if (UCC_OK != status) { + if (comm->event) { + if (rdma_ack_cm_event(comm->event) < 0) { + tl_error(comm->lib, "rdma_ack_cm_event failed"); + return UCC_ERR_NO_RESOURCE; + } + comm->event = NULL; + } + if (status < 0) { + ucc_free(comm->group_setup_info); + } + return status; + } + + ucc_assert(comm->event != NULL); + + comm->mcast_lid = comm->group_setup_info->dlid; + comm->mgid = comm->group_setup_info->dgid; + + ucc_free(comm->group_setup_info); + if (comm->event) { + if (rdma_ack_cm_event(comm->event) < 0) { + tl_error(comm->lib, "rdma_ack_cm_event failed"); + return UCC_ERR_NO_RESOURCE; + } + comm->event = NULL; + } + + /* at this point, none-root rank has joined mcast group */ + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_READY; + + return UCC_INPROGRESS; + } + + case TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_READY: + { + /* setup of the rest of the mcast resources */ + status = ucc_tl_mlx5_mcast_coll_setup_comm_resources(comm); + if (UCC_OK != status) { + return status; + } + + tl_debug(comm->lib, "initialized tl mcast team: %p", tl_team); + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_READY; + + return UCC_INPROGRESS; + } + + case TL_MLX5_TEAM_STATE_MCAST_READY: + case TL_MLX5_TEAM_STATE_MCAST_NOT_AVAILABLE: + { + return UCC_OK; + } + + default: + { + tl_error(comm->lib, "unknown state during mcast team: %p create", tl_team); + return UCC_ERR_NO_RESOURCE; + } + } + } +} diff --git a/src/components/tl/mlx5/tl_mlx5.c b/src/components/tl/mlx5/tl_mlx5.c index 0210f2302c..3e1f24f4ca 100644 --- a/src/components/tl/mlx5/tl_mlx5.c +++ b/src/components/tl/mlx5/tl_mlx5.c @@ -102,6 +102,10 @@ static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = { ucc_offsetof(ucc_tl_mlx5_context_config_t, mcast_ctx_conf.timeout), UCC_CONFIG_TYPE_INT}, + {"MCAST_ENABLE", "0", "Enable Mcast", + ucc_offsetof(ucc_tl_mlx5_context_config_t, mcast_ctx_conf.mcast_enabled), + UCC_CONFIG_TYPE_INT}, + {"MCAST_NET_DEVICE", "", "Specifies which network device to use for Mcast", ucc_offsetof(ucc_tl_mlx5_context_config_t, mcast_ctx_conf.ib_dev_name), UCC_CONFIG_TYPE_STRING}, diff --git a/src/components/tl/mlx5/tl_mlx5.h b/src/components/tl/mlx5/tl_mlx5.h index 155e6144af..8dbe4ff408 100644 --- a/src/components/tl/mlx5/tl_mlx5.h +++ b/src/components/tl/mlx5/tl_mlx5.h @@ -106,27 +106,41 @@ typedef enum TL_MLX5_TEAM_STATE_INIT, TL_MLX5_TEAM_STATE_POSTED, TL_MLX5_TEAM_STATE_ALLTOALL_INIT, - TL_MLX5_TEAM_STATE_ALLTOALL_POSTED + TL_MLX5_TEAM_STATE_ALLTOALL_POSTED, + TL_MLX5_TEAM_STATE_ALLTOALL_READY, + TL_MLX5_TEAM_STATE_ALLTOALL_NOT_AVAILABLE } ucc_tl_mlx5_team_state_t; +typedef enum +{ + TL_MLX5_TEAM_STATE_MCAST_INIT, + TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST, + TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_READY, + TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_FAILED, + TL_MLX5_TEAM_STATE_MCAST_GRP_BCAST_POST, + TL_MLX5_TEAM_STATE_MCAST_READY, + TL_MLX5_TEAM_STATE_MCAST_NOT_AVAILABLE +} ucc_tl_mlx5_team_mcast_state_t; + typedef struct ucc_tl_mlx5_team_status { ucc_status_t local; ucc_status_t global; } ucc_tl_mlx5_team_status_t; typedef struct ucc_tl_mlx5_team { - ucc_tl_team_t super; - ucc_service_coll_req_t *scoll_req; - ucc_tl_mlx5_team_state_t state; - void *dm_offset; - ucc_mpool_t dm_pool; - struct ibv_dm *dm_ptr; - struct ibv_mr *dm_mr; - ucc_tl_mlx5_team_status_t a2a_status; - ucc_tl_mlx5_alltoall_t *a2a; - ucc_topo_t *topo; - ucc_ep_map_t ctx_map; - ucc_tl_mlx5_mcast_team_t *mcast; + ucc_tl_team_t super; + ucc_service_coll_req_t *scoll_req; + ucc_tl_mlx5_team_state_t a2a_state; + ucc_tl_mlx5_team_mcast_state_t mcast_state; + void *dm_offset; + ucc_mpool_t dm_pool; + struct ibv_dm *dm_ptr; + struct ibv_mr *dm_mr; + ucc_tl_mlx5_team_status_t a2a_status; + ucc_tl_mlx5_alltoall_t *a2a; + ucc_topo_t *topo; + ucc_ep_map_t ctx_map; + ucc_tl_mlx5_mcast_team_t *mcast; } ucc_tl_mlx5_team_t; UCC_CLASS_DECLARE(ucc_tl_mlx5_team_t, ucc_base_context_t *, const ucc_base_team_params_t *); diff --git a/src/components/tl/mlx5/tl_mlx5_coll.c b/src/components/tl/mlx5/tl_mlx5_coll.c index 90e224f9f6..861d4a4c67 100644 --- a/src/components/tl/mlx5/tl_mlx5_coll.c +++ b/src/components/tl/mlx5/tl_mlx5_coll.c @@ -45,7 +45,14 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args, ucc_status_t ucc_tl_mlx5_task_finalize(ucc_coll_task_t *coll_task) { - ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t); + ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t); + ucc_tl_mlx5_mcast_coll_req_t *req = task->bcast_mcast.req_handle; + + if (req != NULL) { + ucc_assert(coll_task->status != UCC_INPROGRESS); + ucc_free(req); + task->bcast_mcast.req_handle = NULL; + } tl_trace(UCC_TASK_LIB(task), "finalizing task %p", task); ucc_tl_mlx5_put_task(task); diff --git a/src/components/tl/mlx5/tl_mlx5_coll.h b/src/components/tl/mlx5/tl_mlx5_coll.h index 642dd71581..eb441bdcdf 100644 --- a/src/components/tl/mlx5/tl_mlx5_coll.h +++ b/src/components/tl/mlx5/tl_mlx5_coll.h @@ -79,6 +79,7 @@ ucc_tl_mlx5_get_task(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team) UCC_TL_MLX5_PROFILE_REQUEST_NEW(task, "tl_mlx5_task", 0); ucc_coll_task_init(&task->super, coll_args, team); + task->bcast_mcast.req_handle = NULL; return task; } diff --git a/src/components/tl/mlx5/tl_mlx5_context.c b/src/components/tl/mlx5/tl_mlx5_context.c index 5ac7b59f7d..8e6e764a22 100644 --- a/src/components/tl/mlx5/tl_mlx5_context.c +++ b/src/components/tl/mlx5/tl_mlx5_context.c @@ -51,14 +51,13 @@ UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_context_t, status = ucc_tl_mlx5_mcast_context_init(&(self->mcast), &(self->cfg.mcast_ctx_conf)); if (UCC_OK != status) { + self->mcast.mcast_ready = 0; tl_debug(self->super.super.lib, "failed to initialize mcast context"); - goto err_mcast_context; + } else { + self->mcast.mcast_ready = 1; } - return UCC_OK; -err_mcast_context: - ucc_rcache_destroy(self->rcache); err_rcache: ucc_mpool_cleanup(&self->req_mp, 1); return status; @@ -75,11 +74,15 @@ UCC_CLASS_CLEANUP_FUNC(ucc_tl_mlx5_context_t) tl_debug(self->super.super.lib, "failed to free ib ctx and pd"); }; - if (!self->sock) { + if (self->sock) { close(self->sock); } ucc_mpool_cleanup(&self->req_mp, 1); + + if (self->mcast.mcast_ready) { + ucc_tl_mlx5_mcast_clean_ctx(&self->mcast.mcast_context); + } } UCC_CLASS_DEFINE(ucc_tl_mlx5_context_t, ucc_tl_context_t); diff --git a/src/components/tl/mlx5/tl_mlx5_rcache.c b/src/components/tl/mlx5/tl_mlx5_rcache.c index d6f2aa47d8..630a882f92 100644 --- a/src/components/tl/mlx5/tl_mlx5_rcache.c +++ b/src/components/tl/mlx5/tl_mlx5_rcache.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -52,7 +52,7 @@ static void ucc_tl_mlx5_rcache_dump_region_cb(void *context, //NOLINT snprintf(buf, max, "bar ptr:%p", mlx5_rregion->reg.mr); } -static ucc_rcache_ops_t ucc_rcache_ops = { +static ucc_rcache_ops_t ucc_tl_mlx5_rcache_ops = { .mem_reg = rcache_reg_mr, .mem_dereg = rcache_dereg_mr, .dump_region = ucc_tl_mlx5_rcache_dump_region_cb @@ -60,14 +60,14 @@ static ucc_rcache_ops_t ucc_rcache_ops = { ucc_status_t tl_mlx5_rcache_create(ucc_tl_mlx5_context_t *ctx) { - ucc_rcache_params_t rcache_params; + ucc_rcache_params_t rcache_params; + ucc_rcache_set_default_params(&rcache_params); rcache_params.region_struct_size = sizeof(ucc_tl_mlx5_rcache_region_t); - rcache_params.ucm_event_priority = 1000; - rcache_params.context = (void *)ctx; - rcache_params.ops = &ucc_rcache_ops; - rcache_params.ucm_events = UCM_EVENT_VM_UNMAPPED - | UCM_EVENT_MEM_TYPE_FREE; + rcache_params.context = ctx; + rcache_params.ops = &ucc_tl_mlx5_rcache_ops; + rcache_params.ucm_events = UCM_EVENT_VM_UNMAPPED | + UCM_EVENT_MEM_TYPE_FREE; - return ucc_rcache_create(&rcache_params, "MLX5", &ctx->rcache); + return ucc_rcache_create(&rcache_params, "MLX5_A2A", &ctx->rcache); } diff --git a/src/components/tl/mlx5/tl_mlx5_team.c b/src/components/tl/mlx5/tl_mlx5_team.c index b326166674..16c85b54b9 100644 --- a/src/components/tl/mlx5/tl_mlx5_team.c +++ b/src/components/tl/mlx5/tl_mlx5_team.c @@ -12,6 +12,7 @@ #include "core/ucc_team.h" #include #include "mcast/tl_mlx5_mcast.h" +#include "mcast/tl_mlx5_mcast_helper.h" static ucc_status_t ucc_tl_mlx5_topo_init(ucc_tl_mlx5_team_t *team) { @@ -66,19 +67,22 @@ UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_team_t, ucc_base_context_t *tl_context, } self->a2a = NULL; - status = ucc_tl_mlx5_team_init_alltoall(self); + status = ucc_tl_mlx5_team_init_alltoall(self); if (UCC_OK != status) { return status; } - self->mcast = NULL; - status = ucc_tl_mlx5_mcast_team_init(tl_context, &(self->mcast), &(ctx->mcast), params, - &(UCC_TL_MLX5_TEAM_LIB(self)->cfg.mcast_conf)); + self->mcast = NULL; + status = ucc_tl_mlx5_mcast_team_init(tl_context, &(self->mcast), &(ctx->mcast), params, + &(UCC_TL_MLX5_TEAM_LIB(self)->cfg.mcast_conf)); if (UCC_OK != status) { - return status; + self->mcast_state = TL_MLX5_TEAM_STATE_MCAST_NOT_AVAILABLE; + } else { + self->mcast_state = TL_MLX5_TEAM_STATE_MCAST_INIT; } - self->state = TL_MLX5_TEAM_STATE_INIT; + self->a2a_state = TL_MLX5_TEAM_STATE_INIT; + tl_debug(tl_context->lib, "posted tl team: %p", self); return UCC_OK; } @@ -90,6 +94,9 @@ UCC_CLASS_CLEANUP_FUNC(ucc_tl_mlx5_team_t) ucc_tl_mlx5_dm_cleanup(self); ucc_tl_mlx5_alltoall_cleanup(self); ucc_tl_mlx5_topo_cleanup(self); + if (self->mcast_state != TL_MLX5_TEAM_STATE_MCAST_NOT_AVAILABLE) { + ucc_tl_mlx5_clean_mcast_comm(self->mcast->mcast_comm); + } } UCC_CLASS_DEFINE_DELETE_FUNC(ucc_tl_mlx5_team_t, ucc_base_team_t); @@ -101,15 +108,16 @@ ucc_status_t ucc_tl_mlx5_team_destroy(ucc_base_team_t *tl_team) return UCC_OK; } -ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) +static inline ucc_status_t ucc_tl_mlx5_a2a_team_test(ucc_base_team_t *team) { ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); ucc_team_t *core_team = UCC_TL_CORE_TEAM(tl_team); ucc_subset_t subset = {.map = UCC_TL_TEAM_MAP(tl_team), .myrank = UCC_TL_TEAM_RANK(tl_team)}; + ucc_status_t status = UCC_OK; - switch (tl_team->state) { + switch (tl_team->a2a_state) { case TL_MLX5_TEAM_STATE_INIT: status = ucc_service_allreduce( core_team, &tl_team->a2a_status.local, &tl_team->a2a_status.global, @@ -119,7 +127,7 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) "failed to collect global status"); return status; } - tl_team->state = TL_MLX5_TEAM_STATE_POSTED; + tl_team->a2a_state = TL_MLX5_TEAM_STATE_POSTED; case TL_MLX5_TEAM_STATE_POSTED: status = ucc_service_coll_test(tl_team->scoll_req); if (status < 0) { @@ -132,11 +140,11 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) return status; } ucc_service_coll_finalize(tl_team->scoll_req); - tl_team->state = TL_MLX5_TEAM_STATE_ALLTOALL_INIT; + tl_team->a2a_state = TL_MLX5_TEAM_STATE_ALLTOALL_INIT; case TL_MLX5_TEAM_STATE_ALLTOALL_INIT: tl_team->a2a_status.local = ucc_tl_mlx5_team_test_alltoall_start(tl_team); - tl_team->state = TL_MLX5_TEAM_STATE_ALLTOALL_POSTED; + tl_team->a2a_state = TL_MLX5_TEAM_STATE_ALLTOALL_POSTED; case TL_MLX5_TEAM_STATE_ALLTOALL_POSTED: // coverity[deref_arg:FALSE] tl_team->a2a_status.local = @@ -148,9 +156,52 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) tl_debug(UCC_TL_TEAM_LIB(tl_team), "failed to init a2a: %s", ucc_status_string(tl_team->a2a_status.local)); } + tl_team->a2a_state = TL_MLX5_TEAM_STATE_ALLTOALL_READY; + tl_debug(team->context->lib, "initialized tl a2a team: %p", tl_team); + case TL_MLX5_TEAM_STATE_ALLTOALL_READY: + case TL_MLX5_TEAM_STATE_ALLTOALL_NOT_AVAILABLE: + return UCC_OK; + default: + tl_error(team->context->lib, "unknown state during a2a team: %p create", tl_team); + return UCC_ERR_NO_RESOURCE; + } +} + +ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) +{ + ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); + ucc_status_t a2a_status = UCC_OK; + ucc_status_t mcast_status = UCC_OK; + + a2a_status = ucc_tl_mlx5_a2a_team_test(team); + if (a2a_status < 0) { + tl_error(team->context->lib, "ALLTOALL tl team: %p creation failed %d", + team, a2a_status); + tl_team->a2a_state = TL_MLX5_TEAM_STATE_ALLTOALL_NOT_AVAILABLE; + } + + if (tl_team->mcast_state != TL_MLX5_TEAM_STATE_MCAST_NOT_AVAILABLE) { + mcast_status = ucc_tl_mlx5_mcast_team_test(team); + if (mcast_status < 0) { + tl_error(team->context->lib, "MCAST tl team: %p creation failed %d", + team, mcast_status); + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_NOT_AVAILABLE; + } + } + + if (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_NOT_AVAILABLE && + tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_NOT_AVAILABLE) { + tl_error(team->context->lib, "unable to initialize tl team: %p", team); + return UCC_ERR_NO_RESOURCE; + } + + if (UCC_OK != a2a_status || UCC_OK != mcast_status) { + return UCC_INPROGRESS; } - tl_debug(team->context->lib, "initialized tl team: %p", tl_team); + tl_debug(team->context->lib, "initialized tl team: %p: MCAST component is %s ALLTOALL component is %s", + team, (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY)?"ENABLED":"DISABLED", + (tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_READY)?"ENABLED":"DISABLED"); return UCC_OK; } diff --git a/src/components/tl/sharp/tl_sharp_context.c b/src/components/tl/sharp/tl_sharp_context.c index 72461066b3..5c3140bfa9 100644 --- a/src/components/tl/sharp/tl_sharp_context.c +++ b/src/components/tl/sharp/tl_sharp_context.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -269,15 +269,12 @@ ucc_status_t ucc_tl_sharp_rcache_create(struct sharp_coll_context *context, { ucc_rcache_params_t rcache_params; - rcache_params.ucm_event_priority = 1000; - rcache_params.max_regions = ULONG_MAX; - rcache_params.max_size = SIZE_MAX; + ucc_rcache_set_default_params(&rcache_params); rcache_params.region_struct_size = sizeof(ucc_tl_sharp_rcache_region_t); - rcache_params.ucm_events = UCM_EVENT_VM_UNMAPPED | - UCM_EVENT_MEM_TYPE_FREE; rcache_params.context = context; rcache_params.ops = &ucc_tl_sharp_rcache_ops; - rcache_params.flags = 0; + rcache_params.ucm_events = UCM_EVENT_VM_UNMAPPED | + UCM_EVENT_MEM_TYPE_FREE; return ucc_rcache_create(&rcache_params, "SHARP", rcache); } diff --git a/src/components/tl/ucp/Makefile.am b/src/components/tl/ucp/Makefile.am index bf8e40aa6c..6074ed65c8 100644 --- a/src/components/tl/ucp/Makefile.am +++ b/src/components/tl/ucp/Makefile.am @@ -1,5 +1,5 @@ # -# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # if TL_UCP_ENABLED @@ -14,6 +14,7 @@ allgather = \ allgather/allgather.c \ allgather/allgather_ring.c \ allgather/allgather_neighbor.c \ + allgather/allgather_bruck.c \ allgather/allgather_knomial.c allgatherv = \ diff --git a/src/components/tl/ucp/allgather/allgather.c b/src/components/tl/ucp/allgather/allgather.c index 926b732e55..6900944e66 100644 --- a/src/components/tl/ucp/allgather/allgather.c +++ b/src/components/tl/ucp/allgather/allgather.c @@ -23,6 +23,10 @@ ucc_base_coll_alg_info_t {.id = UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR, .name = "neighbor", .desc = "O(N) Neighbor Exchange N/2 steps"}, + [UCC_TL_UCP_ALLGATHER_ALG_BRUCK] = + {.id = UCC_TL_UCP_ALLGATHER_ALG_BRUCK, + .name = "bruck", + .desc = "O(log(N)) Variation of Bruck algorithm"}, [UCC_TL_UCP_ALLGATHER_ALG_LAST] = { .id = 0, .name = NULL, .desc = NULL}}; diff --git a/src/components/tl/ucp/allgather/allgather.h b/src/components/tl/ucp/allgather/allgather.h index b68ab00e95..ac3592df86 100644 --- a/src/components/tl/ucp/allgather/allgather.h +++ b/src/components/tl/ucp/allgather/allgather.h @@ -12,6 +12,7 @@ enum { UCC_TL_UCP_ALLGATHER_ALG_KNOMIAL, UCC_TL_UCP_ALLGATHER_ALG_RING, UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR, + UCC_TL_UCP_ALLGATHER_ALG_BRUCK, UCC_TL_UCP_ALLGATHER_ALG_LAST }; @@ -56,6 +57,17 @@ void ucc_tl_ucp_allgather_neighbor_progress(ucc_coll_task_t *task); ucc_status_t ucc_tl_ucp_allgather_neighbor_start(ucc_coll_task_t *task); +/* Bruck */ +ucc_status_t ucc_tl_ucp_allgather_bruck_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task_h); + +void ucc_tl_ucp_allgather_bruck_progress(ucc_coll_task_t *task); + +ucc_status_t ucc_tl_ucp_allgather_bruck_start(ucc_coll_task_t *task); + +ucc_status_t ucc_tl_ucp_allgather_bruck_finalize(ucc_coll_task_t *coll_task); + /* Uses allgather_kn_radix from config */ ucc_status_t ucc_tl_ucp_allgather_knomial_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, diff --git a/src/components/tl/ucp/allgather/allgather_bruck.c b/src/components/tl/ucp/allgather/allgather_bruck.c new file mode 100644 index 0000000000..7b831eb2ad --- /dev/null +++ b/src/components/tl/ucp/allgather/allgather_bruck.c @@ -0,0 +1,258 @@ +/** + * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ +#include "config.h" +#include "tl_ucp.h" +#include "allgather.h" +#include "core/ucc_progress_queue.h" +#include "tl_ucp_sendrecv.h" +#include "utils/ucc_math.h" +#include "utils/ucc_coll_utils.h" +#include "components/mc/ucc_mc.h" +#include + +ucc_status_t ucc_tl_ucp_allgather_bruck_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task_h) +{ + ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); + ucc_tl_ucp_task_t *task = ucc_tl_ucp_init_task(coll_args, team); + ucc_status_t status = UCC_OK; + ucc_rank_t trank = UCC_TL_TEAM_RANK(tl_team); + ucc_rank_t tsize = UCC_TL_TEAM_SIZE(tl_team); + ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type; + ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; + size_t count = TASK_ARGS(task).dst.info.count; + size_t data_size = (count / tsize) * ucc_dt_size(dt); + size_t scratch_size = (tsize - trank) * data_size; + + if (!ucc_coll_args_is_predefined_dt(&TASK_ARGS(task), UCC_RANK_INVALID)) { + tl_error(UCC_TASK_LIB(task), "user defined datatype is not supported"); + status = UCC_ERR_NOT_SUPPORTED; + goto out; + } + tl_trace(UCC_TASK_LIB(task), "ucc_tl_ucp_allgather_bruck_init"); + + task->super.post = ucc_tl_ucp_allgather_bruck_start; + task->super.progress = ucc_tl_ucp_allgather_bruck_progress; + task->super.finalize = ucc_tl_ucp_allgather_bruck_finalize; + + /* allocate scratch buffer only on non root rank */ + if (trank != 0) { + if (UCC_MEMORY_TYPE_HOST != rmem) { + scratch_size = tsize * data_size; + } + status = ucc_mc_alloc(&task->allgather_bruck.scratch_header, + scratch_size, UCC_MEMORY_TYPE_HOST); + if (ucc_unlikely(status != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), "failed to allocate scratch buffer"); + ucc_tl_ucp_coll_finalize(&task->super); + goto out; + } + task->allgather_bruck.scratch_size = scratch_size; + } else { + task->allgather_bruck.scratch_header = NULL; + task->allgather_bruck.scratch_size = 0; + } +out: + if (status != UCC_OK) { + ucc_tl_ucp_put_task(task); + return status; + } + + *task_h = &task->super; + return status; +} + +ucc_status_t ucc_tl_ucp_allgather_bruck_finalize(ucc_coll_task_t *coll_task) +{ + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_status_t global_status = UCC_OK; + ucc_status_t status; + + tl_trace(UCC_TASK_LIB(task), "ucc_tl_ucp_allgather_bruck_finalize"); + + if (task->allgather_bruck.scratch_header != NULL) { + /* deallocate scratch buffer */ + global_status = ucc_mc_free(task->allgather_bruck.scratch_header); + if (ucc_unlikely(global_status != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), + "failed to free scratch buffer memory"); + } + task->allgather_bruck.scratch_size = 0; + } + + status = ucc_tl_ucp_coll_finalize(&task->super); + if (ucc_unlikely(status != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), + "failed to finalize allgather bruck collective"); + global_status = status; + } + return global_status; +} + +/* Inspired by implementation: https://github.com/open-mpi/ompi/blob/main/ompi/mca/coll/base/coll_base_allgather.c */ +void ucc_tl_ucp_allgather_bruck_progress(ucc_coll_task_t *coll_task) +{ + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_rank_t trank = UCC_TL_TEAM_RANK(team); + ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team); + void *rbuf = TASK_ARGS(task).dst.info.buffer; + ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type; + ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; + size_t count = TASK_ARGS(task).dst.info.count; + ucc_mc_buffer_header_t *scratch_header = + task->allgather_bruck.scratch_header; + size_t scratch_size = task->allgather_bruck.scratch_size; + size_t data_size = (count / tsize) * ucc_dt_size(dt); + ucc_rank_t recvfrom, sendto; + ucc_status_t status; + size_t blockcount, distance; + void *tmprecv, *tmpsend; + + if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { + return; + } + + /* On each step doubles distance */ + distance = 1 << task->tagged.recv_posted; + tmpsend = rbuf; + while (distance < tsize) { + + recvfrom = (trank + distance) % tsize; + sendto = (trank + tsize - distance) % tsize; + + tmprecv = PTR_OFFSET(tmpsend, distance * data_size); + + if (distance <= tsize >> 1) { + blockcount = distance; + } else { + /* send-recv all reminder */ + blockcount = tsize - distance; + } + + /* Sendreceive */ + UCPCHECK_GOTO(ucc_tl_ucp_send_nb(tmpsend, blockcount * data_size, rmem, + sendto, team, task), + task, out); + UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(tmprecv, blockcount * data_size, rmem, + recvfrom, team, task), + task, out); + + if (UCC_INPROGRESS == ucc_tl_ucp_test_recv(task)) { + return; + } + + distance = 1 << task->tagged.recv_posted; + } + + if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { + return; + } + + /* post processing step */ + if (trank != 0) { + if (UCC_MEMORY_TYPE_HOST == rmem) { + // copy blocks [0 .. (size - rank - 1)] from rbuf to shift buffer + status = ucc_mc_memcpy(scratch_header->addr, rbuf, scratch_size, + UCC_MEMORY_TYPE_HOST, rmem); + if (ucc_unlikely(status != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), + "failed to copy data to scratch buffer"); + task->super.status = status; + return; + } + // move blocks [(size - rank) .. size] from rbuf to beginning of rbuf + // TODO: rewrite to cycle to get rid of overlap + memmove(rbuf, PTR_OFFSET(rbuf, scratch_size), trank * data_size); + // copy blocks from shift buffer starting at block [rank] in rbuf. + status = ucc_mc_memcpy(PTR_OFFSET(rbuf, trank * data_size), + scratch_header->addr, scratch_size, rmem, + UCC_MEMORY_TYPE_HOST); + if (ucc_unlikely(status != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), + "failed to copy data from scratch to rbuff buffer"); + task->super.status = status; + return; + } + } else { + /* In case of non host memory we perform two copy to host buffer and then back to device, 3 memcopy in total */ + /* TODO: replace with generic kernel to do bruck post step in sinle launch on device */ + status = ucc_mc_memcpy( + PTR_OFFSET(scratch_header->addr, trank * data_size), rbuf, + (tsize - trank) * data_size, UCC_MEMORY_TYPE_HOST, rmem); + if (ucc_unlikely(status != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), + "failed to copy first data part to scratch buffer"); + task->super.status = status; + return; + } + status = + ucc_mc_memcpy(scratch_header->addr, + PTR_OFFSET(rbuf, (tsize - trank) * data_size), + trank * data_size, UCC_MEMORY_TYPE_HOST, rmem); + if (ucc_unlikely(status != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), + "failed to copy second data part to scratch buffer"); + task->super.status = status; + return; + } + status = + ucc_mc_memcpy(rbuf, scratch_header->addr, tsize * data_size, + rmem, UCC_MEMORY_TYPE_HOST); + if (ucc_unlikely(status != UCC_OK)) { + tl_error(UCC_TASK_LIB(task), + "failed to copy from scratch buffer to dst"); + task->super.status = status; + return; + } + } + } + + ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task)); + + task->super.status = UCC_OK; + +out: + UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_bruck_done", 0); +} + +ucc_status_t ucc_tl_ucp_allgather_bruck_start(ucc_coll_task_t *coll_task) +{ + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + size_t count = TASK_ARGS(task).dst.info.count; + void *sbuf = TASK_ARGS(task).src.info.buffer; + void *rbuf = TASK_ARGS(task).dst.info.buffer; + ucc_memory_type_t smem = TASK_ARGS(task).src.info.mem_type; + ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type; + ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; + ucc_rank_t trank = UCC_TL_TEAM_RANK(team); + ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team); + size_t data_size = (count / tsize) * ucc_dt_size(dt); + ucc_status_t status; + + UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_bruck_start", 0); + ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); + + /* initial step: copy data on non root ranks to the beginning of buffer */ + if (!UCC_IS_INPLACE(TASK_ARGS(task))) { + // not inplace: copy chunk from source buff to beginning of receive + status = ucc_mc_memcpy(rbuf, sbuf, data_size, rmem, smem); + if (ucc_unlikely(UCC_OK != status)) { + return status; + } + } else if (trank != 0) { + // inplace: copy chunk to the begin + status = ucc_mc_memcpy(rbuf, PTR_OFFSET(rbuf, data_size * trank), + data_size, rmem, rmem); + if (ucc_unlikely(UCC_OK != status)) { + return status; + } + } + + return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); +} diff --git a/src/components/tl/ucp/allreduce/allreduce.c b/src/components/tl/ucp/allreduce/allreduce.c index 4cda3765b8..ef3f6f54ad 100644 --- a/src/components/tl/ucp/allreduce/allreduce.c +++ b/src/components/tl/ucp/allreduce/allreduce.c @@ -21,7 +21,7 @@ ucc_base_coll_alg_info_t .desc = "recursive knomial scatter-reduce followed by knomial " "allgather (optimized for BW)"}, [UCC_TL_UCP_ALLREDUCE_ALG_DBT] = - {.id = UCC_TL_UCP_ALLREDUCE_ALG_SRA_KNOMIAL, + {.id = UCC_TL_UCP_ALLREDUCE_ALG_DBT, .name = "dbt", .desc = "alreduce over double binary tree where a leaf in one tree " "will be intermediate in other (optimized for BW)"}, diff --git a/src/components/tl/ucp/alltoall/alltoall_pairwise.c b/src/components/tl/ucp/alltoall/alltoall_pairwise.c index 1233609cdd..029e346fb9 100644 --- a/src/components/tl/ucp/alltoall/alltoall_pairwise.c +++ b/src/components/tl/ucp/alltoall/alltoall_pairwise.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -11,6 +11,10 @@ #include "utils/ucc_math.h" #include "tl_ucp_sendrecv.h" +/* TODO: add as parameters */ +#define MSG_MEDIUM 66000 +#define NP_THRESH 32 + static inline ucc_rank_t get_recv_peer(ucc_rank_t rank, ucc_rank_t size, ucc_rank_t step) { @@ -23,6 +27,29 @@ static inline ucc_rank_t get_send_peer(ucc_rank_t rank, ucc_rank_t size, return (rank - step + size) % size; } +static ucc_rank_t get_num_posts(const ucc_tl_ucp_team_t *team, + const ucc_coll_args_t *args) +{ + unsigned long posts = UCC_TL_UCP_TEAM_LIB(team)->cfg.alltoall_pairwise_num_posts; + ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team); + size_t data_size; + + data_size = (size_t)args->src.info.count * + ucc_dt_size(args->src.info.datatype); + if (posts == UCC_ULUNITS_AUTO) { + if ((data_size > MSG_MEDIUM) && (tsize > NP_THRESH)) { + /* use pairwise algorithm */ + posts = 1; + } else { + /* use linear algorithm */ + posts = 0; + } + } + + posts = (posts > tsize || posts == 0) ? tsize: posts; + return posts; +} + void ucc_tl_ucp_alltoall_pairwise_progress(ucc_coll_task_t *coll_task) { ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); @@ -34,12 +61,10 @@ void ucc_tl_ucp_alltoall_pairwise_progress(ucc_coll_task_t *coll_task) ucc_rank_t grank = UCC_TL_TEAM_RANK(team); ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team); int polls = 0; - ucc_rank_t peer; - int posts, nreqs; + ucc_rank_t peer, nreqs; size_t data_size; - posts = UCC_TL_UCP_TEAM_LIB(team)->cfg.alltoall_pairwise_num_posts; - nreqs = (posts > gsize || posts == 0) ? gsize : posts; + nreqs = get_num_posts(team, &TASK_ARGS(task)); data_size = (size_t)(TASK_ARGS(task).src.info.count / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype); while ((task->tagged.send_posted < gsize || diff --git a/src/components/tl/ucp/alltoallv/alltoallv_pairwise.c b/src/components/tl/ucp/alltoallv/alltoallv_pairwise.c index 648b79655c..2082a8feba 100644 --- a/src/components/tl/ucp/alltoallv/alltoallv_pairwise.c +++ b/src/components/tl/ucp/alltoallv/alltoallv_pairwise.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -12,6 +12,9 @@ #include "utils/ucc_coll_utils.h" #include "tl_ucp_sendrecv.h" +/* TODO: add as parameters */ +#define NP_THRESH 32 + static inline ucc_rank_t get_recv_peer(ucc_rank_t rank, ucc_rank_t size, ucc_rank_t step) { @@ -24,6 +27,25 @@ static inline ucc_rank_t get_send_peer(ucc_rank_t rank, ucc_rank_t size, return (rank - step + size) % size; } +static ucc_rank_t get_num_posts(const ucc_tl_ucp_team_t *team) +{ + unsigned long posts = UCC_TL_UCP_TEAM_LIB(team)->cfg.alltoallv_pairwise_num_posts; + ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team); + + if (posts == UCC_ULUNITS_AUTO) { + if (UCC_TL_TEAM_SIZE(team) <= NP_THRESH) { + /* use linear algorithm */ + posts = 0; + } else { + /* use pairwise algorithm */ + posts = 1; + } + } + + posts = (posts > tsize || posts == 0) ? tsize: posts; + return posts; +} + static void ucc_tl_ucp_alltoallv_pairwise_progress(ucc_coll_task_t *coll_task) { ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); @@ -35,12 +57,10 @@ static void ucc_tl_ucp_alltoallv_pairwise_progress(ucc_coll_task_t *coll_task) ucc_rank_t grank = UCC_TL_TEAM_RANK(team); ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team); int polls = 0; - ucc_rank_t peer; - int posts, nreqs; + ucc_rank_t peer, nreqs; size_t rdt_size, sdt_size, data_size, data_displ; - posts = UCC_TL_UCP_TEAM_LIB(team)->cfg.alltoallv_pairwise_num_posts; - nreqs = (posts > gsize || posts == 0) ? gsize : posts; + nreqs = get_num_posts(team); rdt_size = ucc_dt_size(TASK_ARGS(task).dst.info_v.datatype); sdt_size = ucc_dt_size(TASK_ARGS(task).src.info_v.datatype); while ((task->tagged.send_posted < gsize || diff --git a/src/components/tl/ucp/bcast/bcast_dbt.c b/src/components/tl/ucp/bcast/bcast_dbt.c index 4e1f77594f..36394edc57 100644 --- a/src/components/tl/ucp/bcast/bcast_dbt.c +++ b/src/components/tl/ucp/bcast/bcast_dbt.c @@ -107,6 +107,8 @@ void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task) task->bcast_dbt.state = SEND_T1; SEND_T1: +/* test_recv is needed to progress ucp_worker */ + ucc_tl_ucp_test_recv(task); if ((coll_root == rank) || (task->bcast_dbt.t1.recv > 0)) { for (i = 0; i < 2; i++) { if ((t1.children[i] != UCC_RANK_INVALID) && @@ -122,6 +124,8 @@ void ucc_tl_ucp_bcast_dbt_progress(ucc_coll_task_t *coll_task) task->bcast_dbt.state = SEND_T2; SEND_T2: +/* test_recv is needed to progress ucp_worker */ + ucc_tl_ucp_test_recv(task); if ((coll_root == rank) || (task->bcast_dbt.t2.recv > 0)) { for (i = 0; i < 2; i++) { if ((t2.children[i] != UCC_RANK_INVALID) && @@ -231,6 +235,7 @@ ucc_status_t ucc_tl_ucp_bcast_dbt_init( task->super.post = ucc_tl_ucp_bcast_dbt_start; task->super.progress = ucc_tl_ucp_bcast_dbt_progress; task->super.finalize = ucc_tl_ucp_bcast_dbt_finalize; + task->n_polls = ucc_max(1, task->n_polls); tl_team = TASK_TEAM(task); rank = UCC_TL_TEAM_RANK(tl_team); size = UCC_TL_TEAM_SIZE(tl_team); diff --git a/src/components/tl/ucp/reduce/reduce_dbt.c b/src/components/tl/ucp/reduce/reduce_dbt.c index 08e8774974..08b8a7aed5 100644 --- a/src/components/tl/ucp/reduce/reduce_dbt.c +++ b/src/components/tl/ucp/reduce/reduce_dbt.c @@ -155,6 +155,8 @@ void ucc_tl_ucp_reduce_dbt_progress(ucc_coll_task_t *coll_task) task->reduce_dbt.state = REDUCE; REDUCE: +/* test_recv is needed to progress ucp_worker */ + ucc_tl_ucp_test_recv(task); for (i = 0; i < 2; i++) { if (trees[i].recv == trees[i].n_children && !task->reduce_dbt.reduction_comp[i]) { @@ -216,6 +218,8 @@ void ucc_tl_ucp_reduce_dbt_progress(ucc_coll_task_t *coll_task) } TEST_ROOT: +/* test_recv is needed to progress ucp_worker */ + ucc_tl_ucp_test_recv(task); if (UCC_INPROGRESS == ucc_tl_ucp_test_send(task) || task->reduce_dbt.reduction_comp[0] != trees[0].recv || task->reduce_dbt.reduction_comp[1] != trees[1].recv) { diff --git a/src/components/tl/ucp/tl_ucp.c b/src/components/tl/ucp/tl_ucp.c index 0b5973de85..72586dc5de 100644 --- a/src/components/tl/ucp/tl_ucp.c +++ b/src/components/tl/ucp/tl_ucp.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -36,17 +36,17 @@ ucc_config_field_t ucc_tl_ucp_lib_config_table[] = { {"", "", NULL, ucc_offsetof(ucc_tl_ucp_lib_config_t, super), UCC_CONFIG_TYPE_TABLE(ucc_tl_lib_config_table)}, - {"ALLTOALL_PAIRWISE_NUM_POSTS", "1", + {"ALLTOALL_PAIRWISE_NUM_POSTS", "auto", "Maximum number of outstanding send and receive messages in alltoall " "pairwise algorithm", ucc_offsetof(ucc_tl_ucp_lib_config_t, alltoall_pairwise_num_posts), - UCC_CONFIG_TYPE_UINT}, + UCC_CONFIG_TYPE_ULUNITS}, - {"ALLTOALLV_PAIRWISE_NUM_POSTS", "1", + {"ALLTOALLV_PAIRWISE_NUM_POSTS", "auto", "Maximum number of outstanding send and receive messages in alltoallv " "pairwise algorithm", ucc_offsetof(ucc_tl_ucp_lib_config_t, alltoallv_pairwise_num_posts), - UCC_CONFIG_TYPE_UINT}, + UCC_CONFIG_TYPE_ULUNITS}, /* TODO: add radix to config once it's fully supported by the algorithm {"ALLTOALLV_HYBRID_RADIX", "2", @@ -86,7 +86,7 @@ ucc_config_field_t ucc_tl_ucp_lib_config_table[] = { "other KN_RADIX values", ucc_offsetof(ucc_tl_ucp_lib_config_t, kn_radix), UCC_CONFIG_TYPE_UINT}, - {"BARRIER_KN_RADIX", "4", + {"BARRIER_KN_RADIX", "8", "Radix of the recursive-knomial barrier algorithm", ucc_offsetof(ucc_tl_ucp_lib_config_t, barrier_kn_radix), UCC_CONFIG_TYPE_UINT}, diff --git a/src/components/tl/ucp/tl_ucp.h b/src/components/tl/ucp/tl_ucp.h index 75fdc76de5..eac2303443 100644 --- a/src/components/tl/ucp/tl_ucp.h +++ b/src/components/tl/ucp/tl_ucp.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -60,8 +60,8 @@ typedef struct ucc_tl_ucp_lib_config { uint32_t scatter_kn_radix; ucc_on_off_auto_value_t scatter_kn_enable_recv_zcopy; uint32_t scatterv_linear_num_posts; - uint32_t alltoall_pairwise_num_posts; - uint32_t alltoallv_pairwise_num_posts; + unsigned long alltoall_pairwise_num_posts; + unsigned long alltoallv_pairwise_num_posts; ucc_pipeline_params_t allreduce_sra_kn_pipeline; int reduce_avg_pre_op; int reduce_scatter_ring_bidirectional; diff --git a/src/components/tl/ucp/tl_ucp_coll.c b/src/components/tl/ucp/tl_ucp_coll.c index 23c254b00e..3b4859b48f 100644 --- a/src/components/tl/ucp/tl_ucp_coll.c +++ b/src/components/tl/ucp/tl_ucp_coll.c @@ -262,6 +262,9 @@ ucc_status_t ucc_tl_ucp_alg_id_to_init(int alg_id, const char *alg_id_str, case UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR: *init = ucc_tl_ucp_allgather_neighbor_init; break; + case UCC_TL_UCP_ALLGATHER_ALG_BRUCK: + *init = ucc_tl_ucp_allgather_bruck_init; + break; default: status = UCC_ERR_INVALID_PARAM; break; diff --git a/src/components/tl/ucp/tl_ucp_coll.h b/src/components/tl/ucp/tl_ucp_coll.h index 6ab2c661dd..0a8a340955 100644 --- a/src/components/tl/ucp/tl_ucp_coll.h +++ b/src/components/tl/ucp/tl_ucp_coll.h @@ -181,6 +181,10 @@ typedef struct ucc_tl_ucp_task { ucc_rank_t tsize, int step); } allgather_ring; + struct { + ucc_mc_buffer_header_t *scratch_header; + size_t scratch_size; + } allgather_bruck; struct { ucc_rank_t dist; uint32_t radix; diff --git a/src/components/topo/ucc_sbgp.c b/src/components/topo/ucc_sbgp.c index e71b8a61ce..e0264ee9e2 100644 --- a/src/components/topo/ucc_sbgp.c +++ b/src/components/topo/ucc_sbgp.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -439,26 +439,82 @@ static int ucc_compare_proc_info_id(const void *a, const void *b) } else if (d1->numa_id != d2->numa_id) { return d1->numa_id - d2->numa_id; } else { - return d1->pid - d2->pid; + return 0; } } static ucc_status_t sbgp_create_full_ordered(ucc_topo_t *topo, ucc_sbgp_t *sbgp) { - ucc_rank_t gsize = ucc_subset_size(&topo->set); - proc_info_id_t *sorted; - ucc_rank_t i; + ucc_rank_t gsize = ucc_subset_size(&topo->set); + ucc_proc_info_t *pinfo = topo->topo->procs; + ucc_host_id_t *visited; + proc_info_id_t *sorted; + ucc_rank_t i, j, num_visited; + int is_sorted, d; ucc_assert(gsize > 0); sbgp->status = UCC_SBGP_ENABLED; sbgp->group_size = gsize; sbgp->group_rank = topo->set.myrank; + sbgp->rank_map = ucc_malloc(sizeof(ucc_rank_t) * gsize, "rank_map"); + if (ucc_unlikely(!sbgp->rank_map)) { + ucc_error("failed to allocate %zd bytes for rank_map", + gsize * sizeof(ucc_rank_t)); + return UCC_ERR_NO_MEMORY; + } + + visited = (ucc_host_id_t *)ucc_malloc(gsize * sizeof(ucc_host_id_t), + "visited host"); + if (ucc_unlikely(!visited)) { + ucc_error("failed to allocate %zd bytes for list of visited nodes", + gsize * sizeof(ucc_host_id_t)); + ucc_free(sbgp->rank_map); + return UCC_ERR_NO_MEMORY; + } + + is_sorted = 1; + num_visited = 1; + visited[0] = pinfo[0].host_hash; + for (i = 1; i < gsize; i++) { + if (pinfo[i].host_hash != pinfo[i-1].host_hash) { + /* check if we saw that host_has before*/ + for (j = 0; j < num_visited; j++) { + if (visited[j] == pinfo[i].host_hash) { + break; + } + } + if (j < num_visited) { + /* this host was present already, ranks are not ordered */ + is_sorted = 0; + break; + } + /* add new host to the list of visited */ + visited[num_visited++] = pinfo[i].host_hash; + } else { + d = ucc_compare_proc_info_id(&pinfo[i - 1].host_hash, + &pinfo[i].host_hash); + + if (d > 0) { + is_sorted = 0; + break; + } + } + } + ucc_free(visited); + + if (is_sorted) { + for (i = 0; i < gsize; i++) { + sbgp->rank_map[i] = i; + } + return UCC_OK; + } sorted = (proc_info_id_t *)ucc_malloc(gsize * sizeof(proc_info_id_t), "proc_sorted"); if (ucc_unlikely(!sorted)) { ucc_error("failed to allocate %zd bytes for sorted proc info", gsize * sizeof(proc_info_id_t)); + ucc_free(sbgp->rank_map); return UCC_ERR_NO_MEMORY; } @@ -467,14 +523,6 @@ static ucc_status_t sbgp_create_full_ordered(ucc_topo_t *topo, ucc_sbgp_t *sbgp) sorted[i].id = i; } - sbgp->rank_map = ucc_malloc(sizeof(ucc_rank_t) * gsize, "rank_map"); - if (ucc_unlikely(!sbgp->rank_map)) { - ucc_error("failed to allocate %zd bytes for rank_map", - gsize * sizeof(ucc_rank_t)); - ucc_free(sorted); - return UCC_ERR_NO_MEMORY; - } - qsort(sorted, gsize, sizeof(proc_info_id_t), ucc_compare_proc_info_id); for (i = 0; i < gsize; i++) { if (sorted[i].id == topo->set.myrank) { diff --git a/src/core/ucc_coll.c b/src/core/ucc_coll.c index 8cf3658570..6cb0426389 100644 --- a/src/core/ucc_coll.c +++ b/src/core/ucc_coll.c @@ -280,20 +280,19 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init, print_trace: *request = &task->super; - if (ucc_global_config.coll_trace.log_level >= UCC_LOG_LEVEL_DIAG) { + if (ucc_unlikely(ucc_global_config.coll_trace.log_level >= + UCC_LOG_LEVEL_DIAG)) { char coll_str[256]; ucc_coll_str(task, coll_str, sizeof(coll_str), ucc_global_config.coll_trace.log_level); - if (ucc_global_config.coll_trace.log_level <= UCC_LOG_LEVEL_DEBUG) { + if (ucc_global_config.coll_trace.log_level <= UCC_LOG_LEVEL_INFO) { if (team->rank == 0) { ucc_log_component_collective_trace( ucc_global_config.coll_trace.log_level, "coll_init: %s", coll_str); } } else { - ucc_log_component_collective_trace( - ucc_global_config.coll_trace.log_level, "coll_init: %s", - coll_str); + ucc_coll_trace_debug("coll_init: %s", coll_str); } } diff --git a/src/core/ucc_context.c b/src/core/ucc_context.c index 7c5fd3c3ca..13e51246b8 100644 --- a/src/core/ucc_context.c +++ b/src/core/ucc_context.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -45,6 +45,11 @@ static ucc_config_field_t ucc_context_config_table[] = { "is configured with OOB (global mode). 0 - disable, 1 - try, 2 - force.", ucc_offsetof(ucc_context_config_t, internal_oob), UCC_CONFIG_TYPE_UINT}, + {"THROTTLE_PROGRESS", "1000", + "Throttle UCC progress to every th invocation", + ucc_offsetof(ucc_context_config_t, throttle_progress), + UCC_CONFIG_TYPE_UINT}, + {NULL}}; UCC_CONFIG_REGISTER_TABLE(ucc_context_config_table, "UCC context", NULL, ucc_context_config_t, &ucc_config_global_list); @@ -614,9 +619,10 @@ ucc_status_t ucc_context_create_proc_info(ucc_lib_h lib, status = UCC_ERR_NO_MEMORY; goto error; } - ctx->rank = UCC_RANK_MAX; - ctx->lib = lib; - ctx->ids.pool_size = config->team_ids_pool_size; + ctx->throttle_progress = config->throttle_progress; + ctx->rank = UCC_RANK_MAX; + ctx->lib = lib; + ctx->ids.pool_size = config->team_ids_pool_size; ucc_list_head_init(&ctx->progress_list); ucc_copy_context_params(&ctx->params, params); ucc_copy_context_params(&b_params.params, params); @@ -903,7 +909,7 @@ ucc_status_t ucc_context_destroy(ucc_context_t *context) tl_ctx = context->tl_ctx[i]; tl_lib = ucc_derived_of(tl_ctx->super.lib, ucc_tl_lib_t); if (tl_ctx->ref_count != 0 ) { - ucc_warn("tl ctx %s is still in use", tl_lib->iface->super.name); + ucc_info("tl ctx %s is still in use", tl_lib->iface->super.name); } tl_lib->iface->context.destroy(&tl_ctx->super); } @@ -957,12 +963,25 @@ ucc_status_t ucc_context_progress_deregister(ucc_context_t *ctx, ucc_status_t ucc_context_progress(ucc_context_h context) { + static int call_num = 0; ucc_status_t status; ucc_context_progress_entry_t *entry; - /* progress registered progress fns */ - ucc_list_for_each(entry, &context->progress_list, list_elem) { - entry->fn(entry->arg); + int is_empty; + + is_empty = ucc_progress_queue_is_empty(context->pq); + if (ucc_likely(is_empty)) { + call_num--; + if (ucc_likely(call_num >= 0)) { + return UCC_OK; + } + /* progress registered progress fns */ + ucc_list_for_each(entry, &context->progress_list, list_elem) { + entry->fn(entry->arg); + } + call_num = context->throttle_progress; + return UCC_OK; } + /* the fn below returns int - number of completed tasks. TODO : do we need to handle it ? Maybe return to user as int as well? */ diff --git a/src/core/ucc_context.h b/src/core/ucc_context.h index a95dd2b920..3944d5675b 100644 --- a/src/core/ucc_context.h +++ b/src/core/ucc_context.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -77,6 +77,7 @@ typedef struct ucc_context { ucc_context_topo_t *topo; uint64_t cl_flags; ucc_tl_team_t *service_team; + int32_t throttle_progress; } ucc_context_t; typedef struct ucc_context_config { @@ -90,6 +91,7 @@ typedef struct ucc_context_config { uint32_t estimated_num_ppn; uint32_t lock_free_progress_q; uint32_t internal_oob; + uint32_t throttle_progress; } ucc_context_config_t; /* Internal function for context creation that takes explicit diff --git a/src/core/ucc_progress_queue.h b/src/core/ucc_progress_queue.h index ba3d20b297..d4ede0c8c3 100644 --- a/src/core/ucc_progress_queue.h +++ b/src/core/ucc_progress_queue.h @@ -1,5 +1,6 @@ /** - * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * * See file LICENSE for terms. */ @@ -14,6 +15,7 @@ struct ucc_progress_queue { void (*enqueue)(ucc_progress_queue_t *pq, ucc_coll_task_t *task); void (*dequeue)(ucc_progress_queue_t *pq, ucc_coll_task_t **task); int (*progress)(ucc_progress_queue_t *pq); + int (*is_empty)(ucc_progress_queue_t *pq); void (*finalize)(ucc_progress_queue_t *pq); }; @@ -46,6 +48,11 @@ static inline int ucc_progress_queue(ucc_progress_queue_t *pq) return pq->progress(pq); } +static inline int ucc_progress_queue_is_empty(ucc_progress_queue_t *pq) +{ + return pq->is_empty(pq); +} + void ucc_progress_queue_finalize(ucc_progress_queue_t *pq); #endif diff --git a/src/core/ucc_progress_queue_mt.c b/src/core/ucc_progress_queue_mt.c index 466628e27c..7da2171f03 100644 --- a/src/core/ucc_progress_queue_mt.c +++ b/src/core/ucc_progress_queue_mt.c @@ -1,5 +1,6 @@ /** - * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * * See file LICENSE for terms. */ @@ -25,7 +26,7 @@ typedef struct ucc_pq_mt_locked { } ucc_pq_mt_locked_t; static void ucc_pq_locked_mt_enqueue(ucc_progress_queue_t *pq, - ucc_coll_task_t * task) + ucc_coll_task_t *task) { ucc_pq_mt_locked_t *pq_mt = ucc_derived_of(pq, ucc_pq_mt_locked_t); @@ -42,7 +43,7 @@ static void ucc_pq_mt_enqueue(ucc_progress_queue_t *pq, ucc_coll_task_t *task) } static void ucc_pq_locked_mt_dequeue(ucc_progress_queue_t *pq, - ucc_coll_task_t ** popped_task) + ucc_coll_task_t **popped_task) { ucc_pq_mt_locked_t *pq_mt = ucc_derived_of(pq, ucc_pq_mt_locked_t); *popped_task = NULL; @@ -56,7 +57,7 @@ static void ucc_pq_locked_mt_dequeue(ucc_progress_queue_t *pq, } static void ucc_pq_mt_dequeue(ucc_progress_queue_t *pq, - ucc_coll_task_t ** popped_task) + ucc_coll_task_t **popped_task) { ucc_pq_mt_t *pq_mt = ucc_derived_of(pq, ucc_pq_mt_t); ucc_lf_queue_elem_t *elem = ucc_lf_queue_dequeue(&pq_mt->lf_queue, 1); @@ -100,6 +101,20 @@ static int ucc_pq_mt_progress(ucc_progress_queue_t *pq) return n_progressed; } +static int ucc_pq_locked_mt_is_empty(ucc_progress_queue_t *pq) +{ + ucc_pq_mt_locked_t *pq_mt = ucc_derived_of(pq, ucc_pq_mt_locked_t); + + /* this function should not be very accurate for the purpose of progress throttling */ + return ucc_list_is_empty(&pq_mt->queue); +} + +static int ucc_pq_mt_is_empty(ucc_progress_queue_t *pq) //NOLINT: pq is unused +{ + /* lock free progress queue never use throttling */ + return 0; +} + static void ucc_pq_locked_mt_finalize(ucc_progress_queue_t *pq) { ucc_pq_mt_locked_t *pq_mt = ucc_derived_of(pq, ucc_pq_mt_locked_t); @@ -128,6 +143,7 @@ ucc_status_t ucc_pq_mt_init(ucc_progress_queue_t **pq, pq_mt->super.dequeue = ucc_pq_mt_dequeue; pq_mt->super.progress = ucc_pq_mt_progress; pq_mt->super.finalize = ucc_pq_mt_finalize; + pq_mt->super.is_empty = ucc_pq_mt_is_empty; *pq = &pq_mt->super; } else { ucc_pq_mt_locked_t *pq_mt = ucc_malloc(sizeof(*pq_mt), "pq_mt"); @@ -141,6 +157,7 @@ ucc_status_t ucc_pq_mt_init(ucc_progress_queue_t **pq, pq_mt->super.dequeue = ucc_pq_locked_mt_dequeue; pq_mt->super.progress = ucc_pq_mt_progress; pq_mt->super.finalize = ucc_pq_locked_mt_finalize; + pq_mt->super.is_empty = ucc_pq_locked_mt_is_empty; *pq = &pq_mt->super; } return UCC_OK; diff --git a/src/core/ucc_progress_queue_st.c b/src/core/ucc_progress_queue_st.c index 048d7313dd..e9842a70d4 100644 --- a/src/core/ucc_progress_queue_st.c +++ b/src/core/ucc_progress_queue_st.c @@ -1,5 +1,6 @@ /** - * Copyright (c) 2020, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * * See file LICENSE for terms. */ @@ -67,6 +68,13 @@ static void ucc_pq_st_finalize(ucc_progress_queue_t *pq) ucc_free(pq_st); } +static int ucc_pq_st_is_empty(ucc_progress_queue_t *pq) +{ + ucc_pq_st_t *pq_st = ucc_derived_of(pq, ucc_pq_st_t); + + return ucc_list_is_empty(&pq_st->list); +} + ucc_status_t ucc_pq_st_init(ucc_progress_queue_t **pq) { ucc_pq_st_t *pq_st = ucc_malloc(sizeof(*pq_st), "pq_st"); @@ -79,6 +87,8 @@ ucc_status_t ucc_pq_st_init(ucc_progress_queue_t **pq) pq_st->super.dequeue = NULL; pq_st->super.progress = ucc_pq_st_progress; pq_st->super.finalize = ucc_pq_st_finalize; + pq_st->super.is_empty = ucc_pq_st_is_empty; + *pq = &pq_st->super; return UCC_OK; } diff --git a/src/ucc/api/ucc.h b/src/ucc/api/ucc.h index a269dfb940..9f8ee5f145 100644 --- a/src/ucc/api/ucc.h +++ b/src/ucc/api/ucc.h @@ -1,7 +1,7 @@ /** * @file ucc.h * @date 2020 - * @copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * @copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * @copyright Copyright (C) Huawei Technologies Co., Ltd. 2020. ALL RIGHTS RESERVED. * @copyright Copyright (C) UChicago Argonne, LLC. 2022. ALL RIGHTS RESERVED. * @@ -726,15 +726,6 @@ void ucc_get_version(unsigned *major_version, unsigned *minor_version, */ const char *ucc_get_version_string(void); -/** - * @ingroup UCC_LIB - * @brief Get UCC library version as a string. - * - * This routine returns the UCC library version as a string which consists of: - * "major.minor.release". - */ -const char *ucc_get_version_string(void); - /** * @ingroup UCC_LIB_INTERNAL diff --git a/src/utils/arch/aarch64/cpu.h b/src/utils/arch/aarch64/cpu.h index c160ab41c2..81d16bb3f2 100644 --- a/src/utils/arch/aarch64/cpu.h +++ b/src/utils/arch/aarch64/cpu.h @@ -1,5 +1,5 @@ /** -* Copyright (c) 2001-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +* Copyright (c) 2001-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (C) ARM Ltd. 2016-2020. ALL RIGHTS RESERVED. * Copyright (C) Stony Brook University. 2016-2020. ALL RIGHTS RESERVED. * @@ -47,11 +47,6 @@ typedef struct ucc_aarch64_cpuid { */ void ucc_aarch64_cpuid(ucc_aarch64_cpuid_t *cpuid); -static inline ucc_cpu_model_t ucc_arch_get_cpu_model() -{ - return UCC_CPU_MODEL_ARM_AARCH64; -} - static inline ucc_cpu_vendor_t ucc_arch_get_cpu_vendor() { ucc_aarch64_cpuid_t cpuid; @@ -61,7 +56,25 @@ static inline ucc_cpu_vendor_t ucc_arch_get_cpu_vendor() return UCC_CPU_VENDOR_FUJITSU_ARM; } + if ((cpuid.implementer == 0x41) && (cpuid.architecture == 8)) { + return UCC_CPU_VENDOR_NVIDIA; + } + return UCC_CPU_VENDOR_GENERIC_ARM; } +static inline ucc_cpu_model_t ucc_arch_get_cpu_model() +{ + ucc_aarch64_cpuid_t cpuid; + ucc_aarch64_cpuid(&cpuid); + + if ((ucc_arch_get_cpu_vendor() == UCC_CPU_VENDOR_NVIDIA) && + (cpuid.part == 0xd4f)) { + return UCC_CPU_MODEL_NVIDIA_GRACE; + } + + return UCC_CPU_MODEL_ARM_AARCH64; +} + + #endif diff --git a/src/utils/arch/cpu.h b/src/utils/arch/cpu.h index 17a74195b7..636e8b60b7 100644 --- a/src/utils/arch/cpu.h +++ b/src/utils/arch/cpu.h @@ -1,5 +1,5 @@ /** -* Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2001-2023. ALL RIGHTS RESERVED. +* Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2001-2024. ALL RIGHTS RESERVED. * Copyright (C) ARM Ltd. 2016. ALL RIGHTS RESERVED. * Copyright (C) Shanghai Zhaoxin Semiconductor Co., Ltd. 2020. ALL RIGHTS RESERVED. * Copyright (C) Rivos Inc. 2023 @@ -35,6 +35,7 @@ typedef enum ucc_cpu_model { UCC_CPU_MODEL_ZHAOXIN_ZHANGJIANG, UCC_CPU_MODEL_ZHAOXIN_WUDAOKOU, UCC_CPU_MODEL_ZHAOXIN_LUJIAZUI, + UCC_CPU_MODEL_NVIDIA_GRACE, UCC_CPU_MODEL_LAST } ucc_cpu_model_t; @@ -48,6 +49,7 @@ typedef enum ucc_cpu_vendor { UCC_CPU_VENDOR_GENERIC_RISCV, UCC_CPU_VENDOR_FUJITSU_ARM, UCC_CPU_VENDOR_ZHAOXIN, + UCC_CPU_VENDOR_NVIDIA, UCC_CPU_VENDOR_LAST } ucc_cpu_vendor_t; @@ -67,6 +69,8 @@ static inline ucc_cpu_vendor_t ucc_get_vendor_from_str(const char *v_name) return UCC_CPU_VENDOR_FUJITSU_ARM; if (strcasecmp(v_name, "zhaoxin") == 0) return UCC_CPU_VENDOR_ZHAOXIN; + if (strcasecmp(v_name, "nvidia") == 0) + return UCC_CPU_VENDOR_NVIDIA; return UCC_CPU_VENDOR_UNKNOWN; } @@ -102,6 +106,8 @@ static inline ucc_cpu_model_t ucc_get_model_from_str(const char *m_name) return UCC_CPU_MODEL_ZHAOXIN_WUDAOKOU; if (strcasecmp(m_name, "lujiazui") == 0) return UCC_CPU_MODEL_ZHAOXIN_LUJIAZUI; + if (strcasecmp(m_name, "grace") == 0) + return UCC_CPU_MODEL_NVIDIA_GRACE; return UCC_CPU_MODEL_UNKNOWN; } diff --git a/src/utils/ucc_coll_utils.c b/src/utils/ucc_coll_utils.c index 75a49400e2..533a9e4fb3 100644 --- a/src/utils/ucc_coll_utils.c +++ b/src/utils/ucc_coll_utils.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -56,7 +56,8 @@ static inline int ucc_coll_args_is_mem_symmetric(const ucc_coll_args_t *args, ucc_rank_t rank) { - ucc_rank_t root = args->root; + ucc_rank_t root = args->root; + if (UCC_IS_INPLACE(*args)) { return 1; } @@ -93,7 +94,7 @@ ucc_coll_args_is_mem_symmetric(const ucc_coll_args_t *args, return 0; } -int ucc_coll_args_is_predefined_dt(ucc_coll_args_t *args, ucc_rank_t rank) +int ucc_coll_args_is_predefined_dt(const ucc_coll_args_t *args, ucc_rank_t rank) { switch (args->coll_type) { case UCC_COLL_TYPE_BARRIER: @@ -160,7 +161,7 @@ int ucc_coll_args_is_predefined_dt(ucc_coll_args_t *args, ucc_rank_t rank) ucc_memory_type_t ucc_coll_args_mem_type(const ucc_coll_args_t *args, ucc_rank_t rank) { - ucc_rank_t root = args->root; + ucc_rank_t root = args->root; if (!ucc_coll_args_is_mem_symmetric(args, rank)) { return UCC_MEMORY_TYPE_ASYMMETRIC; diff --git a/src/utils/ucc_coll_utils.h b/src/utils/ucc_coll_utils.h index c5cb2ef392..5b24f19cf2 100644 --- a/src/utils/ucc_coll_utils.h +++ b/src/utils/ucc_coll_utils.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -332,6 +332,6 @@ static inline size_t ucc_buffer_block_offset_aligned(size_t total_count, @param [in] args pointer to the collective args. @param [in] rank rank to check, used only for rooted collective operations. */ -int ucc_coll_args_is_predefined_dt(ucc_coll_args_t *args, ucc_rank_t rank); +int ucc_coll_args_is_predefined_dt(const ucc_coll_args_t *args, ucc_rank_t rank); #endif diff --git a/src/utils/ucc_component.c b/src/utils/ucc_component.c index b19bd2e397..83d4aa8558 100644 --- a/src/utils/ucc_component.c +++ b/src/utils/ucc_component.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2020, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * See file LICENSE for terms. */ #include "config.h" @@ -52,7 +52,9 @@ static ucc_status_t ucc_component_load_one(const char *so_path, handle = dlopen(so_path, RTLD_LAZY); if (!handle) { - ucc_debug("failed to load UCC component library: %s", so_path); + error = dlerror(); + ucc_debug("failed to load UCC component library: %s (%s)", + so_path, error); goto error; } iface = (ucc_component_iface_t *)dlsym(handle, iface_struct); diff --git a/src/utils/ucc_log.h b/src/utils/ucc_log.h index b480ee55ae..d9aafab7ab 100644 --- a/src/utils/ucc_log.h +++ b/src/utils/ucc_log.h @@ -67,6 +67,26 @@ #define ucc_coll_trace_debug(_fmt, ...) \ ucc_log_component_collective_trace(UCS_LOG_LEVEL_DEBUG, _fmt, ##__VA_ARGS__) +/** + * Print a message regardless of current log level. Output can be + * enabled/disabled via environment variable/configuration settings. + * + * During debugging it can be useful to add a few prints to the code + * without changing a current log level. Also it is useful to be able + * to see messages only from specific processes. For example, one may + * want to see prints only from rank 0 when debugging MPI. + * + * The function is intended for debugging only. It should not be used + * in the real code. + */ + +#define ucc_print(_fmt, ...) \ + do { \ + ucs_log_dispatch(__FILE__, __LINE__, __FUNCTION__, \ + UCS_LOG_LEVEL_PRINT, \ + &ucc_global_config.log_component, \ + _fmt, ## __VA_ARGS__); \ + } while(0) static inline const char* ucc_coll_type_str(ucc_coll_type_t ct) { diff --git a/src/utils/ucc_parser.h b/src/utils/ucc_parser.h index 517dd88be8..90a1c085ef 100644 --- a/src/utils/ucc_parser.h +++ b/src/utils/ucc_parser.h @@ -66,7 +66,6 @@ typedef struct ucc_file_config ucc_file_config_t; #define UCC_CONFIG_TYPE_ULUNITS UCS_CONFIG_TYPE_ULUNITS #define UCC_CONFIG_TYPE_ENUM UCS_CONFIG_TYPE_ENUM #define UCC_CONFIG_TYPE_MEMUNITS UCS_CONFIG_TYPE_MEMUNITS -#define UCC_CONFIG_TYPE_ULUNITS UCS_CONFIG_TYPE_ULUNITS #define UCC_ULUNITS_AUTO UCS_ULUNITS_AUTO #define UCC_CONFIG_TYPE_BITMAP UCS_CONFIG_TYPE_BITMAP #define UCC_CONFIG_TYPE_MEMUNITS UCS_CONFIG_TYPE_MEMUNITS diff --git a/src/utils/ucc_rcache.h b/src/utils/ucc_rcache.h index 46993caacb..3e89396a93 100644 --- a/src/utils/ucc_rcache.h +++ b/src/utils/ucc_rcache.h @@ -1,5 +1,6 @@ /** - * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * * See file LICENSE for terms. */ @@ -16,10 +17,23 @@ #define ucc_rcache_params_t ucs_rcache_params_t #define ucc_rcache_region_t ucs_rcache_region_t -#define ucc_rcache_destroy ucs_rcache_destroy -#define ucc_rcache_region_hold ucs_rcache_region_hold -#define ucc_rcache_region_put ucs_rcache_region_put -#define ucc_rcache_region_invalidate ucs_rcache_region_invalidate +static inline void ucc_rcache_set_default_params(ucs_rcache_params_t *rcache_params) +{ + rcache_params->region_struct_size = sizeof(ucs_rcache_region_t); + rcache_params->ucm_events = 0; + rcache_params->ucm_event_priority = 1000; + rcache_params->ops = NULL; + rcache_params->context = NULL; + rcache_params->flags = 0; + rcache_params->max_regions = UCS_MEMUNITS_INF; + rcache_params->max_size = UCS_MEMUNITS_INF; + rcache_params->max_unreleased = UCS_MEMUNITS_INF; +} + +#define ucc_rcache_destroy ucs_rcache_destroy +#define ucc_rcache_region_hold ucs_rcache_region_hold +#define ucc_rcache_region_put ucs_rcache_region_put +#define ucc_rcache_region_invalidate ucs_rcache_region_invalidate /* Wrapper functions for status conversion */ static inline ucc_status_t @@ -46,16 +60,17 @@ static inline ucc_status_t ucc_rcache_get(ucc_rcache_t *rcache, void *address, size_t length, void *arg, ucc_rcache_region_t **region_p) { + ucs_status_t status; + #ifdef UCS_HAVE_RCACHE_REGION_ALIGNMENT - return ucs_status_to_ucc_status(ucs_rcache_get( - rcache, address, length, - ucc_get_page_size(), - PROT_READ | PROT_WRITE, arg, region_p)); + status = ucs_rcache_get(rcache, address, length, ucc_get_page_size(), + PROT_READ | PROT_WRITE, arg, region_p); #else - return ucs_status_to_ucc_status(ucs_rcache_get( - rcache, address, length, - PROT_READ | PROT_WRITE, arg, region_p)); + status = ucs_rcache_get(rcache, address, length, + PROT_READ | PROT_WRITE, arg, region_p); #endif + + return ucs_status_to_ucc_status(status); } #endif diff --git a/test/gtest/coll/test_allgather.cc b/test/gtest/coll/test_allgather.cc index e1cacac5ac..2577bdf26c 100644 --- a/test/gtest/coll/test_allgather.cc +++ b/test/gtest/coll/test_allgather.cc @@ -296,7 +296,7 @@ INSTANTIATE_TEST_CASE_P( #endif ::testing::Values(1,3,8192), // count ::testing::Values(TEST_INPLACE, TEST_NO_INPLACE), - ::testing::Values("knomial", "ring", "neighbor")), + ::testing::Values("knomial", "ring", "neighbor", "bruck")), [](const testing::TestParamInfo& info) { std::string name; name += ucc_datatype_str(std::get<0>(info.param)); diff --git a/test/gtest/coll/test_reduce.cc b/test/gtest/coll/test_reduce.cc index 0f8bfc034f..2fb1cbc963 100644 --- a/test/gtest/coll/test_reduce.cc +++ b/test/gtest/coll/test_reduce.cc @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -283,55 +283,66 @@ template class test_reduce_avg_order : public test_reduce { template class test_reduce_dbt : public test_reduce { }; -#define TEST_DECLARE_WITH_ENV(_env, _n_procs) \ - { \ - UccJob job(_n_procs, UccJob::UCC_JOB_CTX_GLOBAL, _env); \ - UccTeam_h team = job.create_team(_n_procs); \ - int repeat = 3; \ - UccCollCtxVec ctxs; \ - std::vector mt = {UCC_MEMORY_TYPE_HOST}; \ - if (UCC_OK == ucc_mc_available(UCC_MEMORY_TYPE_CUDA)) { \ - mt.push_back(UCC_MEMORY_TYPE_CUDA); \ - } \ - if (UCC_OK == ucc_mc_available(UCC_MEMORY_TYPE_CUDA_MANAGED)) { \ - mt.push_back(UCC_MEMORY_TYPE_CUDA_MANAGED); \ - } \ - for (auto count : {5, 256, 65536}) { \ - for (auto inplace : {TEST_NO_INPLACE, TEST_INPLACE}) { \ - for (auto m : mt) { \ - CHECK_TYPE_OP_SKIP(TypeParam::dt, TypeParam::redop, m); \ - SET_MEM_TYPE(m); \ - this->set_inplace(inplace); \ - this->data_init(_n_procs, TypeParam::dt, count, ctxs, true); \ - UccReq req(team, ctxs); \ - CHECK_REQ_NOT_SUPPORTED_SKIP(req, this->data_fini(ctxs)); \ - for (auto i = 0; i < repeat; i++) { \ - req.start(); \ - req.wait(); \ - EXPECT_EQ(true, this->data_validate(ctxs)); \ - this->reset(ctxs); \ - } \ - this->data_fini(ctxs); \ - } \ - } \ - } \ +template class test_reduce_2step : public test_reduce { +}; + +#define TEST_DECLARE_WITH_ENV(_env, _n_procs, _persistent) \ + { \ + UccJob job(_n_procs, UccJob::UCC_JOB_CTX_GLOBAL, _env); \ + UccTeam_h team = job.create_team(_n_procs); \ + int repeat = _persistent ? 3 : 1; \ + UccCollCtxVec ctxs; \ + std::vector mt = {UCC_MEMORY_TYPE_HOST}; \ + if (UCC_OK == ucc_mc_available(UCC_MEMORY_TYPE_CUDA)) { \ + mt.push_back(UCC_MEMORY_TYPE_CUDA); \ + } \ + if (UCC_OK == ucc_mc_available(UCC_MEMORY_TYPE_CUDA_MANAGED)) { \ + mt.push_back(UCC_MEMORY_TYPE_CUDA_MANAGED); \ + } \ + for (auto count : {5, 256, 65536}) { \ + for (auto inplace : {TEST_NO_INPLACE, TEST_INPLACE}) { \ + for (auto m : mt) { \ + CHECK_TYPE_OP_SKIP(TypeParam::dt, TypeParam::redop, m); \ + SET_MEM_TYPE(m); \ + this->set_inplace(inplace); \ + this->data_init(_n_procs, TypeParam::dt, count, ctxs, \ + _persistent); \ + UccReq req(team, ctxs); \ + CHECK_REQ_NOT_SUPPORTED_SKIP(req, this->data_fini(ctxs)); \ + for (auto i = 0; i < repeat; i++) { \ + req.start(); \ + req.wait(); \ + EXPECT_EQ(true, this->data_validate(ctxs)); \ + this->reset(ctxs); \ + } \ + this->data_fini(ctxs); \ + } \ + } \ + } \ } TYPED_TEST_CASE(test_reduce_avg_order, CollReduceTypeOpsAvg); TYPED_TEST_CASE(test_reduce_dbt, CollReduceTypeOpsHost); +TYPED_TEST_CASE(test_reduce_2step, CollReduceTypeOpsHost); -ucc_job_env_t post_op_env = {{"UCC_TL_UCP_REDUCE_AVG_PRE_OP", "0"}}; -ucc_job_env_t reduce_dbt_env = {{"UCC_TL_UCP_TUNE", "reduce:@dbt:0-inf:inf"}, - {"UCC_CLS", "basic"}}; +ucc_job_env_t post_op_env = {{"UCC_TL_UCP_REDUCE_AVG_PRE_OP", "0"}}; +ucc_job_env_t reduce_dbt_env = {{"UCC_TL_UCP_TUNE", "reduce:@dbt:0-inf:inf"}, + {"UCC_CLS", "basic"}}; +ucc_job_env_t reduce_2step_env = {{"UCC_CL_HIER_TUNE", "reduce:@2step:0-inf:inf"}, + {"UCC_CLS", "all"}}; TYPED_TEST(test_reduce_avg_order, avg_post_op) { - TEST_DECLARE_WITH_ENV(post_op_env, 15); + TEST_DECLARE_WITH_ENV(post_op_env, 15, true); } TYPED_TEST(test_reduce_dbt, reduce_dbt_shift) { - TEST_DECLARE_WITH_ENV(reduce_dbt_env, 15); + TEST_DECLARE_WITH_ENV(reduce_dbt_env, 15, true); } TYPED_TEST(test_reduce_dbt, reduce_dbt_mirror) { - TEST_DECLARE_WITH_ENV(reduce_dbt_env, 16); + TEST_DECLARE_WITH_ENV(reduce_dbt_env, 16, true); +} + +TYPED_TEST(test_reduce_2step, 2step) { + TEST_DECLARE_WITH_ENV(reduce_2step_env, 16, false); } diff --git a/test/mpi/main.cc b/test/mpi/main.cc index f4a571fa14..716d1d4b50 100644 --- a/test/mpi/main.cc +++ b/test/mpi/main.cc @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) Advanced Micro Devices, Inc. 2023. ALL RIGHTS RESERVED. * * See file LICENSE for terms. @@ -574,8 +574,8 @@ void ProcessArgs(int argc, char** argv) int main(int argc, char *argv[]) { - int failed = 0; - int total_done_skipped_failed[ucc_ilog2(UCC_COLL_TYPE_LAST) + 1][4] = {0}; + int failed = 0; + int total_done_skipped_failed[ucc_ilog2(UCC_COLL_TYPE_LAST) + 1][4]; std::chrono::steady_clock::time_point begin; int size, required, provided, completed, rank; UccTestMpi *test; @@ -583,6 +583,8 @@ int main(int argc, char *argv[]) std::string err; begin = std::chrono::steady_clock::now(); + memset(total_done_skipped_failed, 0, + sizeof(total_done_skipped_failed)); try { ProcessArgs(argc, argv); } catch (const std::string &s) { diff --git a/test/mpi/test_allgather.cc b/test/mpi/test_allgather.cc index ebca8c4c95..a98bbd8ee4 100644 --- a/test/mpi/test_allgather.cc +++ b/test/mpi/test_allgather.cc @@ -51,7 +51,7 @@ ucc_status_t TestAllgather::set_input(int iter_persistent) size_t single_rank_count = msgsize / dt_size; size_t single_rank_size = single_rank_count * dt_size; int rank; - void *buf, *check; + void *buf; this->iter_persistent = iter_persistent; MPI_Comm_rank(team.comm, &rank); @@ -60,12 +60,9 @@ ucc_status_t TestAllgather::set_input(int iter_persistent) } else { buf = sbuf; } - check = PTR_OFFSET(check_buf, rank * single_rank_size); init_buffer(buf, single_rank_count, dt, mem_type, rank * (iter_persistent + 1)); - UCC_CHECK(ucc_mc_memcpy(check, buf, single_rank_size, - UCC_MEMORY_TYPE_HOST, mem_type)); return UCC_OK; } @@ -83,7 +80,6 @@ ucc_status_t TestAllgather::check() i * (iter_persistent + 1)); } - return compare_buffers(rbuf, check_buf, single_rank_count * size, dt, mem_type); } diff --git a/test/mpi/test_bcast.cc b/test/mpi/test_bcast.cc index 080cbb436f..d5d1c92a7c 100644 --- a/test/mpi/test_bcast.cc +++ b/test/mpi/test_bcast.cc @@ -49,8 +49,6 @@ ucc_status_t TestBcast::set_input(int iter_persistent) MPI_Comm_rank(team.comm, &rank); if (rank == root) { init_buffer(sbuf, count, dt, mem_type, rank * (iter_persistent + 1)); - UCC_CHECK(ucc_mc_memcpy(check_buf, sbuf, count * dt_size, - UCC_MEMORY_TYPE_HOST, mem_type)); } return UCC_OK; } @@ -61,9 +59,11 @@ ucc_status_t TestBcast::check() int rank; MPI_Comm_rank(team.comm, &rank); + if (rank == root) { + return UCC_OK; + } + init_buffer(check_buf, count, dt, UCC_MEMORY_TYPE_HOST, root * (iter_persistent + 1)); - return (rank == root) - ? UCC_OK - : compare_buffers(sbuf, check_buf, count, dt, mem_type); + return compare_buffers(sbuf, check_buf, count, dt, mem_type); } diff --git a/test/mpi/test_gather.cc b/test/mpi/test_gather.cc index 4b87fe5397..0f21455fab 100644 --- a/test/mpi/test_gather.cc +++ b/test/mpi/test_gather.cc @@ -75,8 +75,9 @@ ucc_status_t TestGather::set_input(int iter_persistent) size_t single_rank_count = msgsize / dt_size; size_t single_rank_size = single_rank_count * dt_size; int rank; - void *buf, *check; + void *buf; + this->iter_persistent = iter_persistent; MPI_Comm_rank(team.comm, &rank); if (rank == root) { if (inplace) { @@ -87,34 +88,30 @@ ucc_status_t TestGather::set_input(int iter_persistent) } else { buf = sbuf; } - check = PTR_OFFSET(check_buf, rank * single_rank_size); - init_buffer(buf, single_rank_count, dt, mem_type, rank * (iter_persistent + 1)); - UCC_CHECK(ucc_mc_memcpy(check, buf, single_rank_size, - UCC_MEMORY_TYPE_HOST, mem_type)); return UCC_OK; } ucc_status_t TestGather::check() { - size_t single_rank_count = msgsize / ucc_dt_size(dt); - MPI_Datatype mpi_dt = ucc_dt_to_mpi(dt); - MPI_Request req; - int size, rank, completed; + int size, rank, i; + size_t dt_size, single_rank_count; MPI_Comm_size(team.comm, &size); MPI_Comm_rank(team.comm, &rank); + if (rank != root) { + return UCC_OK; + } - MPI_Iallgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, check_buf, - single_rank_count, mpi_dt, team.comm, &req); - do { - MPI_Test(&req, &completed, MPI_STATUS_IGNORE); - ucc_context_progress(team.ctx); - } while(!completed); + dt_size = ucc_dt_size(dt); + single_rank_count = msgsize / dt_size; + for (i = 0; i < size; i++) { + init_buffer(PTR_OFFSET(check_buf, i * single_rank_count * dt_size), + single_rank_count, dt, UCC_MEMORY_TYPE_HOST, + i * (iter_persistent + 1)); + } - return (rank != root) - ? UCC_OK - : compare_buffers(rbuf, check_buf, single_rank_count * size, dt, - mem_type); + return compare_buffers(rbuf, check_buf, single_rank_count * size, dt, + mem_type); } diff --git a/test/mpi/test_gatherv.cc b/test/mpi/test_gatherv.cc index 7468a56307..445b9faa14 100644 --- a/test/mpi/test_gatherv.cc +++ b/test/mpi/test_gatherv.cc @@ -106,8 +106,9 @@ ucc_status_t TestGatherv::set_input(int iter_persistent) { size_t dt_size = ucc_dt_size(dt); int rank; - void *buf, *check; + void *buf; + this->iter_persistent = iter_persistent; MPI_Comm_rank(team.comm, &rank); if (rank == root) { if (inplace) { @@ -118,11 +119,8 @@ ucc_status_t TestGatherv::set_input(int iter_persistent) } else { buf = sbuf; } - check = PTR_OFFSET(check_buf, displacements[rank] * dt_size); - init_buffer(buf, counts[rank], dt, mem_type, rank * (iter_persistent + 1)); - UCC_CHECK(ucc_mc_memcpy(check, buf, counts[rank] * dt_size, - UCC_MEMORY_TYPE_HOST, mem_type)); + return UCC_OK; } @@ -138,21 +136,21 @@ TestGatherv::~TestGatherv() ucc_status_t TestGatherv::check() { - size_t count = msgsize / ucc_dt_size(dt); - MPI_Datatype mpi_dt = ucc_dt_to_mpi(dt); - MPI_Request req; - int size, rank, completed; + size_t count = msgsize / ucc_dt_size(dt); + int size, rank, i; MPI_Comm_size(team.comm, &size); MPI_Comm_rank(team.comm, &rank); - MPI_Iallgatherv(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, check_buf, - (int *)counts, (int *)displacements, mpi_dt, team.comm, - &req); - do { - MPI_Test(&req, &completed, MPI_STATUS_IGNORE); - ucc_context_progress(team.ctx); - } while(!completed); + if (rank != root) { + return UCC_OK; + } + + for (i = 0; i < size; i++) { + init_buffer(PTR_OFFSET(check_buf, displacements[i] * ucc_dt_size(dt)), + counts[i], dt, UCC_MEMORY_TYPE_HOST, + i * (iter_persistent + 1)); + } return (rank != root) ? UCC_OK diff --git a/test/mpi/test_scatter.cc b/test/mpi/test_scatter.cc index 016ed7465b..4d4438b635 100644 --- a/test/mpi/test_scatter.cc +++ b/test/mpi/test_scatter.cc @@ -25,7 +25,6 @@ TestScatter::TestScatter(ucc_test_team_t &_team, TestCaseParams ¶ms) : TEST_SKIP_MEM_LIMIT, team.comm)) { return; } - if (rank == root) { UCC_CHECK(ucc_mc_alloc(&sbuf_mc_header, msgsize * size, mem_type)); sbuf = sbuf_mc_header->addr;