From 900626e9decc345a989b7a3095209c32fdb69a1f Mon Sep 17 00:00:00 2001 From: Sergey Lebedev Date: Tue, 6 Feb 2024 12:21:50 +0000 Subject: [PATCH] CL/HIER: check number of TLs per SBGP --- src/components/cl/hier/cl_hier_team.c | 48 ++++++++++++++++++++++++--- test/mpi/main.cc | 8 +++-- 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/src/components/cl/hier/cl_hier_team.c b/src/components/cl/hier/cl_hier_team.c index 31f0e9f707..8457e3db83 100644 --- a/src/components/cl/hier/cl_hier_team.c +++ b/src/components/cl/hier/cl_hier_team.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -43,6 +43,11 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context, ucc_config_names_array_t *tls; ucc_subset_t subset; struct ucc_team_team_desc *d; + ucc_tl_context_t *tl_ctx; + ucc_tl_lib_t *tl_lib; + ucc_base_lib_attr_t attr; + + if (!params->team->topo) { cl_debug(cl_context->lib, "can't create hier team without topology data"); @@ -74,6 +79,13 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context, hs->n_tls = 0; tls = &lib->cfg.sbgp_tls[i].array; for (j = 0; j < tls->count; j++) { + if (hs->n_tls == CL_HIER_MAX_SBGP_TLS) { + cl_debug(cl_context->lib, + "skipping tl context %s for %s sbgp: " + "max number of TLs per SBGP is reached", + tls->names[j], ucc_sbgp_str(hs->sbgp_type)); + continue; + } status = ucc_tl_context_get(ctx->super.super.ucc_context, tls->names[j], &hs->tl_ctxs[hs->n_tls]); @@ -81,11 +93,37 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context, cl_debug(cl_context->lib, "tl context %s is not available for sbgp %s", tls->names[j], ucc_sbgp_str(hs->sbgp_type)); - } else { - hs->n_tls++; - n_sbgp_teams++; - ucc_assert(hs->n_tls <= CL_HIER_MAX_SBGP_TLS); + continue; } + attr.mask = UCC_BASE_LIB_ATTR_FIELD_MIN_TEAM_SIZE | + UCC_BASE_LIB_ATTR_FIELD_MAX_TEAM_SIZE; + tl_ctx = hs->tl_ctxs[hs->n_tls]; + tl_lib = ucc_derived_of(tl_ctx->super.lib, ucc_tl_lib_t); + status = tl_lib->iface->lib.get_attr(tl_ctx->super.lib, + &attr); + if (status != UCC_OK) { + cl_debug(cl_context->lib, + "failed to get attributes for tl context %s", + tls->names[j]); + ucc_tl_context_put(tl_ctx); + continue; + } + + if (hs->sbgp->group_size < attr.min_team_size || + hs->sbgp->group_size > attr.max_team_size) { + cl_debug(cl_context->lib, + "tl context %s is not suitable for sbgp %s" + "sbgp: sbgp size %d is not in range [%d; %d]", + tls->names[j], ucc_sbgp_str(hs->sbgp_type), + hs->sbgp->group_size, + attr.min_team_size, attr.max_team_size); + ucc_tl_context_put(tl_ctx); + continue; + } + + hs->n_tls++; + n_sbgp_teams++; + ucc_assert(hs->n_tls <= CL_HIER_MAX_SBGP_TLS); } } } diff --git a/test/mpi/main.cc b/test/mpi/main.cc index f4a571fa14..716d1d4b50 100644 --- a/test/mpi/main.cc +++ b/test/mpi/main.cc @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) Advanced Micro Devices, Inc. 2023. ALL RIGHTS RESERVED. * * See file LICENSE for terms. @@ -574,8 +574,8 @@ void ProcessArgs(int argc, char** argv) int main(int argc, char *argv[]) { - int failed = 0; - int total_done_skipped_failed[ucc_ilog2(UCC_COLL_TYPE_LAST) + 1][4] = {0}; + int failed = 0; + int total_done_skipped_failed[ucc_ilog2(UCC_COLL_TYPE_LAST) + 1][4]; std::chrono::steady_clock::time_point begin; int size, required, provided, completed, rank; UccTestMpi *test; @@ -583,6 +583,8 @@ int main(int argc, char *argv[]) std::string err; begin = std::chrono::steady_clock::now(); + memset(total_done_skipped_failed, 0, + sizeof(total_done_skipped_failed)); try { ProcessArgs(argc, argv); } catch (const std::string &s) {