Skip to content

Commit

Permalink
CL/HIER: check number of TLs per SBGP
Browse files: browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev committed Feb 6, 2024
1 parent e66574d commit 061d289
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 8 deletions.
48 changes: 43 additions & 5 deletions src/components/cl/hier/cl_hier_team.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -43,6 +43,11 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context,
ucc_config_names_array_t *tls;
ucc_subset_t subset;
struct ucc_team_team_desc *d;
ucc_tl_context_t *tl_ctx;
ucc_tl_lib_t *tl_lib;
ucc_base_lib_attr_t attr;


if (!params->team->topo) {
cl_debug(cl_context->lib,
"can't create hier team without topology data");
Expand Down Expand Up @@ -74,18 +79,51 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context,
hs->n_tls = 0;
tls = &lib->cfg.sbgp_tls[i].array;
for (j = 0; j < tls->count; j++) {
if (hs->n_tls == CL_HIER_MAX_SBGP_TLS) {
cl_debug(cl_context->lib,
"skipping tl context %s for %s sbgp: "
"max number of TLs per SBGP is reached",
tls->names[j], ucc_sbgp_str(hs->sbgp_type));
continue;
}
status = ucc_tl_context_get(ctx->super.super.ucc_context,
tls->names[j],
&hs->tl_ctxs[hs->n_tls]);
if (UCC_OK != status) {
cl_debug(cl_context->lib,
"tl context %s is not available for sbgp %s",
tls->names[j], ucc_sbgp_str(hs->sbgp_type));
} else {
hs->n_tls++;
n_sbgp_teams++;
ucc_assert(hs->n_tls <= CL_HIER_MAX_SBGP_TLS);
continue;
}
attr.mask = UCC_BASE_LIB_ATTR_FIELD_MIN_TEAM_SIZE |
UCC_BASE_LIB_ATTR_FIELD_MAX_TEAM_SIZE;
tl_ctx = hs->tl_ctxs[hs->n_tls];
tl_lib = ucc_derived_of(tl_ctx->super.lib, ucc_tl_lib_t);
status = tl_lib->iface->lib.get_attr(tl_ctx->super.lib,
&attr);
if (status != UCC_OK) {
cl_debug(cl_context->lib,
"failed to get attributes for tl context %s",
tls->names[j]);
ucc_tl_context_put(tl_ctx);
continue;
}

if (hs->sbgp->group_size < attr.min_team_size ||
hs->sbgp->group_size > attr.max_team_size) {
cl_debug(cl_context->lib,
"tl context %s is not suitable for sbgp %s"
"sbgp: sbgp size %d is not in range [%d; %d]",
tls->names[j], ucc_sbgp_str(hs->sbgp_type),
hs->sbgp->group_size,
attr.min_team_size, attr.max_team_size);
ucc_tl_context_put(tl_ctx);
continue;
}

hs->n_tls++;
n_sbgp_teams++;
ucc_assert(hs->n_tls <= CL_HIER_MAX_SBGP_TLS);
}
}
}
Expand Down
8 changes: 5 additions & 3 deletions test/mpi/main.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) Advanced Micro Devices, Inc. 2023. ALL RIGHTS RESERVED.
*
* See file LICENSE for terms.
Expand Down Expand Up @@ -574,15 +574,17 @@ void ProcessArgs(int argc, char** argv)

int main(int argc, char *argv[])
{
int failed = 0;
int total_done_skipped_failed[ucc_ilog2(UCC_COLL_TYPE_LAST) + 1][4] = {0};
int failed = 0;
int total_done_skipped_failed[ucc_ilog2(UCC_COLL_TYPE_LAST) + 1][4];
std::chrono::steady_clock::time_point begin;
int size, required, provided, completed, rank;
UccTestMpi *test;
MPI_Request req;
std::string err;

begin = std::chrono::steady_clock::now();
memset(total_done_skipped_failed, 0,
sizeof(total_done_skipped_failed));
try {
ProcessArgs(argc, argv);
} catch (const std::string &s) {
Expand Down

0 comments on commit 061d289

Please sign in to comment.