From d7cd5a340b3b9cfba052f06b9ad3d952586e572c Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Wed, 3 Jul 2024 18:35:02 +0300 Subject: [PATCH] UTIL: Set rcache alignment to UCS_PGT_ADDR_ALIGN --- src/components/tl/mlx5/tl_mlx5_team.c | 87 ++++++++++++++++----------- src/utils/ucc_rcache.h | 2 +- 2 files changed, 53 insertions(+), 36 deletions(-) diff --git a/src/components/tl/mlx5/tl_mlx5_team.c b/src/components/tl/mlx5/tl_mlx5_team.c index edc0b95d90..28139deb66 100644 --- a/src/components/tl/mlx5/tl_mlx5_team.c +++ b/src/components/tl/mlx5/tl_mlx5_team.c @@ -78,8 +78,9 @@ UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_team_t, ucc_base_context_t *tl_context, self->local_mcast_team_ready = 0; if (ctx->mcast.mcast_ctx_ready) { - status = ucc_tl_mlx5_mcast_team_init(tl_context, &(self->mcast), &(ctx->mcast), - params, &(UCC_TL_MLX5_TEAM_LIB(self)->cfg.mcast_conf)); + status = ucc_tl_mlx5_mcast_team_init( + tl_context, &(self->mcast), &(ctx->mcast), params, + &(UCC_TL_MLX5_TEAM_LIB(self)->cfg.mcast_conf)); if (UCC_OK != status) { tl_warn(tl_context->lib, "mcast team init failed"); } else { @@ -119,7 +120,7 @@ ucc_status_t ucc_tl_mlx5_team_destroy(ucc_base_team_t *tl_team) static inline ucc_status_t ucc_tl_mlx5_a2a_team_test(ucc_base_team_t *team) { - ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); + ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); switch (tl_team->a2a_state) { case TL_MLX5_TEAM_STATE_ALLTOALL_INIT: @@ -154,16 +155,15 @@ static inline ucc_status_t ucc_tl_mlx5_a2a_team_test(ucc_base_team_t *team) ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) { - ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); - ucc_team_t *core_team = UCC_TL_CORE_TEAM(tl_team); - ucc_subset_t subset = {.map = UCC_TL_TEAM_MAP(tl_team), - .myrank = UCC_TL_TEAM_RANK(tl_team)}; - ucc_status_t a2a_status = UCC_OK; - ucc_status_t mcast_status = UCC_OK; - ucc_tl_mlx5_mcast_coll_comm_t *comm = NULL; + ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); + ucc_team_t *core_team = UCC_TL_CORE_TEAM(tl_team); + ucc_subset_t subset = {.map = UCC_TL_TEAM_MAP(tl_team), + .myrank = UCC_TL_TEAM_RANK(tl_team)}; + ucc_status_t a2a_status = UCC_OK; + ucc_status_t mcast_status = UCC_OK; + ucc_tl_mlx5_mcast_coll_comm_t *comm = NULL; ucc_status_t status; - if (tl_team->global_sync_req != NULL) { status = ucc_service_coll_test(tl_team->global_sync_req); if (status < 0) { @@ -181,15 +181,18 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) tl_team->global_sync_req = NULL; if (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_CTX_CHECK && - tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_CTX_CHECK ) { - tl_team->a2a_status.global = tl_team->global_status_array[UCC_TL_MLX5_A2A_STATUS_INDEX]; - tl_team->a2a_state = TL_MLX5_TEAM_STATE_ALLTOALL_INIT; + tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_CTX_CHECK) { + tl_team->a2a_status.global = + tl_team->global_status_array[UCC_TL_MLX5_A2A_STATUS_INDEX]; + tl_team->a2a_state = TL_MLX5_TEAM_STATE_ALLTOALL_INIT; - if (tl_team->global_status_array[UCC_TL_MLX5_MCAST_STATUS_INDEX] != UCC_OK) { + if (tl_team->global_status_array[UCC_TL_MLX5_MCAST_STATUS_INDEX] != + UCC_OK) { /* mcast context is not available for some of the team members so we cannot create * mcast team */ tl_debug(UCC_TL_TEAM_LIB(tl_team), - "some of the ranks do not have mcast context available so no mcast team is created"); + "some of the ranks do not have mcast context " + "available so no mcast team is created"); if (tl_team->local_mcast_team_ready) { comm = tl_team->mcast->mcast_comm; @@ -215,7 +218,8 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) return UCC_INPROGRESS; } else { - if (tl_team->global_status_array[UCC_TL_MLX5_A2A_STATUS_INDEX] != UCC_OK) { + if (tl_team->global_status_array[UCC_TL_MLX5_A2A_STATUS_INDEX] != + UCC_OK) { //a2a team not avail for some of nodes so disable it if (tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_READY) { // free the resources @@ -224,7 +228,8 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) tl_team->a2a_state = TL_MLX5_TEAM_STATE_ALLTOALL_NOT_AVAILABLE; } - if (tl_team->global_status_array[UCC_TL_MLX5_MCAST_STATUS_INDEX] != UCC_OK) { + if (tl_team->global_status_array[UCC_TL_MLX5_MCAST_STATUS_INDEX] != + UCC_OK) { //mcast team not avail for some of nodes so disable it if (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY) { // free the resources @@ -233,16 +238,22 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_NOT_AVAILABLE; } - tl_debug(team->context->lib, "team %p: MCAST component is %s ALLTOALL component is %s", - team, (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY)?"ENABLED":"DISABLED", - (tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_READY)?"ENABLED":"DISABLED"); + tl_debug(team->context->lib, + "team %p: MCAST component is %s ALLTOALL component is %s", + team, + (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY) + ? "ENABLED" + : "DISABLED", + (tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_READY) + ? "ENABLED" + : "DISABLED"); } return UCC_OK; } ucc_assert(tl_team->global_sync_req == NULL); - + if (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_CTX_CHECK && tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_CTX_CHECK) { // check if ctx is ready for a2a and mcast @@ -274,9 +285,13 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) } tl_team->local_status_array[UCC_TL_MLX5_A2A_STATUS_INDEX] = - (tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_READY) ? UCC_OK : UCC_ERR_NO_RESOURCE; + (tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_READY) + ? UCC_OK + : UCC_ERR_NO_RESOURCE; tl_team->local_status_array[UCC_TL_MLX5_MCAST_STATUS_INDEX] = - (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY) ? UCC_OK : UCC_ERR_NO_RESOURCE; + (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY) + ? UCC_OK + : UCC_ERR_NO_RESOURCE; tl_debug(UCC_TL_TEAM_LIB(tl_team), "posting global status, local status: ALLTOALL %d MCAST %d", @@ -284,27 +299,27 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY)); initial_sync_post: - status = ucc_service_allreduce( - core_team, tl_team->local_status_array, tl_team->global_status_array, - UCC_DT_INT32, UCC_TL_MLX5_FEATURES_COUNT, UCC_OP_MIN, subset, &tl_team->global_sync_req); + status = ucc_service_allreduce(core_team, tl_team->local_status_array, + tl_team->global_status_array, UCC_DT_INT32, + UCC_TL_MLX5_FEATURES_COUNT, UCC_OP_MIN, + subset, &tl_team->global_sync_req); if (status < 0) { - tl_debug(UCC_TL_TEAM_LIB(tl_team), - "failed to collect global status"); + tl_debug(UCC_TL_TEAM_LIB(tl_team), "failed to collect global status"); return status; } return UCC_INPROGRESS; } -ucc_status_t ucc_tl_mlx5_team_get_scores(ucc_base_team_t * tl_team, +ucc_status_t ucc_tl_mlx5_team_get_scores(ucc_base_team_t *tl_team, ucc_coll_score_t **score_p) { ucc_tl_mlx5_team_t *team = ucc_derived_of(tl_team, ucc_tl_mlx5_team_t); ucc_base_context_t *ctx = UCC_TL_TEAM_CTX(team); ucc_base_lib_t *lib = UCC_TL_TEAM_LIB(team); ucc_memory_type_t mt[2] = {UCC_MEMORY_TYPE_HOST, UCC_MEMORY_TYPE_CUDA}; - ucc_coll_score_t *score; - ucc_status_t status; + ucc_coll_score_t *score; + ucc_status_t status; ucc_coll_score_team_info_t team_info; team_info.alg_fn = NULL; @@ -313,9 +328,11 @@ ucc_status_t ucc_tl_mlx5_team_get_scores(ucc_base_team_t * tl_team, team_info.num_mem_types = 2; team_info.supported_mem_types = mt; team_info.supported_colls = - (UCC_COLL_TYPE_ALLTOALL * (team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_READY) * 0) | - UCC_COLL_TYPE_BCAST * (team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY); - team_info.size = UCC_TL_TEAM_SIZE(team); + (UCC_COLL_TYPE_ALLTOALL * + (team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_READY)) | + UCC_COLL_TYPE_BCAST * + (team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY); + team_info.size = UCC_TL_TEAM_SIZE(team); status = ucc_coll_score_build_default( tl_team, UCC_TL_MLX5_DEFAULT_SCORE, ucc_tl_mlx5_coll_init, diff --git a/src/utils/ucc_rcache.h b/src/utils/ucc_rcache.h index 3e89396a93..1995e11ba8 100644 --- a/src/utils/ucc_rcache.h +++ b/src/utils/ucc_rcache.h @@ -63,7 +63,7 @@ ucc_rcache_get(ucc_rcache_t *rcache, void *address, size_t length, ucs_status_t status; #ifdef UCS_HAVE_RCACHE_REGION_ALIGNMENT - status = ucs_rcache_get(rcache, address, length, ucc_get_page_size(), + status = ucs_rcache_get(rcache, address, length, UCS_PGT_ADDR_ALIGN, PROT_READ | PROT_WRITE, arg, region_p); #else status = ucs_rcache_get(rcache, address, length,