Skip to content

Commit

Permalink
Merge pull request openucx#919 from Sergei-Lebedev/topic/cl_hier_max_…
Browse files Browse the repository at this point in the history
…tls_sbgp

CL/HIER: check number of TLs per SBGP
  • Loading branch information
bureddy authored Feb 6, 2024
2 parents e66574d + fa2c434 commit c13d26c
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 10 deletions.
48 changes: 43 additions & 5 deletions src/components/cl/hier/cl_hier_team.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -43,6 +43,11 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context,
ucc_config_names_array_t *tls;
ucc_subset_t subset;
struct ucc_team_team_desc *d;
ucc_tl_context_t *tl_ctx;
ucc_tl_lib_t *tl_lib;
ucc_base_lib_attr_t attr;


if (!params->team->topo) {
cl_debug(cl_context->lib,
"can't create hier team without topology data");
Expand Down Expand Up @@ -74,18 +79,51 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context,
hs->n_tls = 0;
tls = &lib->cfg.sbgp_tls[i].array;
for (j = 0; j < tls->count; j++) {
if (hs->n_tls == CL_HIER_MAX_SBGP_TLS) {
cl_debug(cl_context->lib,
"skipping tl context %s for %s sbgp: "
"max number of TLs per SBGP is reached",
tls->names[j], ucc_sbgp_str(hs->sbgp_type));
continue;
}
status = ucc_tl_context_get(ctx->super.super.ucc_context,
tls->names[j],
&hs->tl_ctxs[hs->n_tls]);
if (UCC_OK != status) {
cl_debug(cl_context->lib,
"tl context %s is not available for sbgp %s",
tls->names[j], ucc_sbgp_str(hs->sbgp_type));
} else {
hs->n_tls++;
n_sbgp_teams++;
ucc_assert(hs->n_tls <= CL_HIER_MAX_SBGP_TLS);
continue;
}
attr.mask = UCC_BASE_LIB_ATTR_FIELD_MIN_TEAM_SIZE |
UCC_BASE_LIB_ATTR_FIELD_MAX_TEAM_SIZE;
tl_ctx = hs->tl_ctxs[hs->n_tls];
tl_lib = ucc_derived_of(tl_ctx->super.lib, ucc_tl_lib_t);
status = tl_lib->iface->lib.get_attr(tl_ctx->super.lib,
&attr);
if (status != UCC_OK) {
cl_debug(cl_context->lib,
"failed to get attributes for tl context %s",
tls->names[j]);
ucc_tl_context_put(tl_ctx);
continue;
}

if (hs->sbgp->group_size < attr.min_team_size ||
hs->sbgp->group_size > attr.max_team_size) {
cl_debug(cl_context->lib,
"tl context %s is not suitable for sbgp %s"
"sbgp: sbgp size %d is not in range [%d; %d]",
tls->names[j], ucc_sbgp_str(hs->sbgp_type),
hs->sbgp->group_size,
attr.min_team_size, attr.max_team_size);
ucc_tl_context_put(tl_ctx);
continue;
}

hs->n_tls++;
n_sbgp_teams++;
ucc_assert(hs->n_tls <= CL_HIER_MAX_SBGP_TLS);
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/components/ec/cuda/ec_cuda.c
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,9 @@ ucc_status_t ucc_ec_cuda_get_resources(ucc_ec_cuda_resources_t **resources)
#else
status = CUDADRV_FUNC(cuCtxGetId(cu_ctx, &cu_ctx_id));
if (ucc_unlikely(status != UCC_OK)) {
ec_error(&ucc_ec_cuda.super, "failed to get currect CUDA context ID");
/* worakround for pytorch, progress thread doesn't have cuda context for GPU 0*/
cu_ctx_id = 0x12345;
ec_debug(&ucc_ec_cuda.super, "failed to get currect CUDA context ID");
}
#endif

Expand Down
4 changes: 3 additions & 1 deletion src/components/mc/cuda/mc_cuda.c
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,9 @@ ucc_status_t ucc_mc_cuda_get_resources(ucc_mc_cuda_resources_t **resources)
#else
status = CUDADRV_FUNC(cuCtxGetId(cu_ctx, &cu_ctx_id));
if (ucc_unlikely(status != UCC_OK)) {
mc_error(&ucc_mc_cuda.super, "failed to get currect CUDA context ID");
/* worakround for pytorch, progress thread doesn't have cuda context for GPU 0*/
cu_ctx_id = 0x12345;
mc_debug(&ucc_mc_cuda.super, "failed to get currect CUDA context ID");
}
#endif

Expand Down
8 changes: 5 additions & 3 deletions test/mpi/main.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) Advanced Micro Devices, Inc. 2023. ALL RIGHTS RESERVED.
*
* See file LICENSE for terms.
Expand Down Expand Up @@ -574,15 +574,17 @@ void ProcessArgs(int argc, char** argv)

int main(int argc, char *argv[])
{
int failed = 0;
int total_done_skipped_failed[ucc_ilog2(UCC_COLL_TYPE_LAST) + 1][4] = {0};
int failed = 0;
int total_done_skipped_failed[ucc_ilog2(UCC_COLL_TYPE_LAST) + 1][4];
std::chrono::steady_clock::time_point begin;
int size, required, provided, completed, rank;
UccTestMpi *test;
MPI_Request req;
std::string err;

begin = std::chrono::steady_clock::now();
memset(total_done_skipped_failed, 0,
sizeof(total_done_skipped_failed));
try {
ProcessArgs(argc, argv);
} catch (const std::string &s) {
Expand Down

0 comments on commit c13d26c

Please sign in to comment.