Skip to content

Commit

Permalink
UTIL: Set rcache alignment to UCS_PGT_ADDR_ALIGN
Browse files Browse the repository at this point in the history
  • Loading branch information
x41lakazam committed Aug 19, 2024
1 parent 37adf47 commit d7cd5a3
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 36 deletions.
87 changes: 52 additions & 35 deletions src/components/tl/mlx5/tl_mlx5_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@ UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_team_t, ucc_base_context_t *tl_context,

self->local_mcast_team_ready = 0;
if (ctx->mcast.mcast_ctx_ready) {
status = ucc_tl_mlx5_mcast_team_init(tl_context, &(self->mcast), &(ctx->mcast),
params, &(UCC_TL_MLX5_TEAM_LIB(self)->cfg.mcast_conf));
status = ucc_tl_mlx5_mcast_team_init(
tl_context, &(self->mcast), &(ctx->mcast), params,
&(UCC_TL_MLX5_TEAM_LIB(self)->cfg.mcast_conf));
if (UCC_OK != status) {
tl_warn(tl_context->lib, "mcast team init failed");
} else {
Expand Down Expand Up @@ -119,7 +120,7 @@ ucc_status_t ucc_tl_mlx5_team_destroy(ucc_base_team_t *tl_team)

static inline ucc_status_t ucc_tl_mlx5_a2a_team_test(ucc_base_team_t *team)
{
ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t);
ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t);

switch (tl_team->a2a_state) {
case TL_MLX5_TEAM_STATE_ALLTOALL_INIT:
Expand Down Expand Up @@ -154,16 +155,15 @@ static inline ucc_status_t ucc_tl_mlx5_a2a_team_test(ucc_base_team_t *team)

ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team)
{
ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t);
ucc_team_t *core_team = UCC_TL_CORE_TEAM(tl_team);
ucc_subset_t subset = {.map = UCC_TL_TEAM_MAP(tl_team),
.myrank = UCC_TL_TEAM_RANK(tl_team)};
ucc_status_t a2a_status = UCC_OK;
ucc_status_t mcast_status = UCC_OK;
ucc_tl_mlx5_mcast_coll_comm_t *comm = NULL;
ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t);
ucc_team_t *core_team = UCC_TL_CORE_TEAM(tl_team);
ucc_subset_t subset = {.map = UCC_TL_TEAM_MAP(tl_team),
.myrank = UCC_TL_TEAM_RANK(tl_team)};
ucc_status_t a2a_status = UCC_OK;
ucc_status_t mcast_status = UCC_OK;
ucc_tl_mlx5_mcast_coll_comm_t *comm = NULL;
ucc_status_t status;


if (tl_team->global_sync_req != NULL) {
status = ucc_service_coll_test(tl_team->global_sync_req);
if (status < 0) {
Expand All @@ -181,15 +181,18 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team)
tl_team->global_sync_req = NULL;

if (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_CTX_CHECK &&
tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_CTX_CHECK ) {
tl_team->a2a_status.global = tl_team->global_status_array[UCC_TL_MLX5_A2A_STATUS_INDEX];
tl_team->a2a_state = TL_MLX5_TEAM_STATE_ALLTOALL_INIT;
tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_CTX_CHECK) {
tl_team->a2a_status.global =
tl_team->global_status_array[UCC_TL_MLX5_A2A_STATUS_INDEX];
tl_team->a2a_state = TL_MLX5_TEAM_STATE_ALLTOALL_INIT;

if (tl_team->global_status_array[UCC_TL_MLX5_MCAST_STATUS_INDEX] != UCC_OK) {
if (tl_team->global_status_array[UCC_TL_MLX5_MCAST_STATUS_INDEX] !=
UCC_OK) {
/* The mcast context is not available on some of the team members, so the
* mcast team cannot be created. */
tl_debug(UCC_TL_TEAM_LIB(tl_team),
"some of the ranks do not have mcast context available so no mcast team is created");
"some of the ranks do not have mcast context "
"available so no mcast team is created");

if (tl_team->local_mcast_team_ready) {
comm = tl_team->mcast->mcast_comm;
Expand All @@ -215,7 +218,8 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team)

return UCC_INPROGRESS;
} else {
if (tl_team->global_status_array[UCC_TL_MLX5_A2A_STATUS_INDEX] != UCC_OK) {
if (tl_team->global_status_array[UCC_TL_MLX5_A2A_STATUS_INDEX] !=
UCC_OK) {
// a2a team is not available on some of the nodes, so disable it
if (tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_READY) {
// free the resources
Expand All @@ -224,7 +228,8 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team)
tl_team->a2a_state = TL_MLX5_TEAM_STATE_ALLTOALL_NOT_AVAILABLE;
}

if (tl_team->global_status_array[UCC_TL_MLX5_MCAST_STATUS_INDEX] != UCC_OK) {
if (tl_team->global_status_array[UCC_TL_MLX5_MCAST_STATUS_INDEX] !=
UCC_OK) {
// mcast team is not available on some of the nodes, so disable it
if (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY) {
// free the resources
Expand All @@ -233,16 +238,22 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team)
tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_NOT_AVAILABLE;
}

tl_debug(team->context->lib, "team %p: MCAST component is %s ALLTOALL component is %s",
team, (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY)?"ENABLED":"DISABLED",
(tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_READY)?"ENABLED":"DISABLED");
tl_debug(team->context->lib,
"team %p: MCAST component is %s ALLTOALL component is %s",
team,
(tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY)
? "ENABLED"
: "DISABLED",
(tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_READY)
? "ENABLED"
: "DISABLED");
}

return UCC_OK;
}

ucc_assert(tl_team->global_sync_req == NULL);

if (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_CTX_CHECK &&
tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_CTX_CHECK) {
// check if ctx is ready for a2a and mcast
Expand Down Expand Up @@ -274,37 +285,41 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team)
}

tl_team->local_status_array[UCC_TL_MLX5_A2A_STATUS_INDEX] =
(tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_READY) ? UCC_OK : UCC_ERR_NO_RESOURCE;
(tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_READY)
? UCC_OK
: UCC_ERR_NO_RESOURCE;
tl_team->local_status_array[UCC_TL_MLX5_MCAST_STATUS_INDEX] =
(tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY) ? UCC_OK : UCC_ERR_NO_RESOURCE;
(tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY)
? UCC_OK
: UCC_ERR_NO_RESOURCE;

tl_debug(UCC_TL_TEAM_LIB(tl_team),
"posting global status, local status: ALLTOALL %d MCAST %d",
(tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_READY),
(tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY));

initial_sync_post:
status = ucc_service_allreduce(
core_team, tl_team->local_status_array, tl_team->global_status_array,
UCC_DT_INT32, UCC_TL_MLX5_FEATURES_COUNT, UCC_OP_MIN, subset, &tl_team->global_sync_req);
status = ucc_service_allreduce(core_team, tl_team->local_status_array,
tl_team->global_status_array, UCC_DT_INT32,
UCC_TL_MLX5_FEATURES_COUNT, UCC_OP_MIN,
subset, &tl_team->global_sync_req);
if (status < 0) {
tl_debug(UCC_TL_TEAM_LIB(tl_team),
"failed to collect global status");
tl_debug(UCC_TL_TEAM_LIB(tl_team), "failed to collect global status");
return status;
}

return UCC_INPROGRESS;
}

ucc_status_t ucc_tl_mlx5_team_get_scores(ucc_base_team_t * tl_team,
ucc_status_t ucc_tl_mlx5_team_get_scores(ucc_base_team_t *tl_team,
ucc_coll_score_t **score_p)
{
ucc_tl_mlx5_team_t *team = ucc_derived_of(tl_team, ucc_tl_mlx5_team_t);
ucc_base_context_t *ctx = UCC_TL_TEAM_CTX(team);
ucc_base_lib_t *lib = UCC_TL_TEAM_LIB(team);
ucc_memory_type_t mt[2] = {UCC_MEMORY_TYPE_HOST, UCC_MEMORY_TYPE_CUDA};
ucc_coll_score_t *score;
ucc_status_t status;
ucc_coll_score_t *score;
ucc_status_t status;
ucc_coll_score_team_info_t team_info;

team_info.alg_fn = NULL;
Expand All @@ -313,9 +328,11 @@ ucc_status_t ucc_tl_mlx5_team_get_scores(ucc_base_team_t * tl_team,
team_info.num_mem_types = 2;
team_info.supported_mem_types = mt;
team_info.supported_colls =
(UCC_COLL_TYPE_ALLTOALL * (team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_READY) * 0) |
UCC_COLL_TYPE_BCAST * (team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY);
team_info.size = UCC_TL_TEAM_SIZE(team);
(UCC_COLL_TYPE_ALLTOALL *
(team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_READY)) |
UCC_COLL_TYPE_BCAST *
(team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY);
team_info.size = UCC_TL_TEAM_SIZE(team);

status = ucc_coll_score_build_default(
tl_team, UCC_TL_MLX5_DEFAULT_SCORE, ucc_tl_mlx5_coll_init,
Expand Down
2 changes: 1 addition & 1 deletion src/utils/ucc_rcache.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ ucc_rcache_get(ucc_rcache_t *rcache, void *address, size_t length,
ucs_status_t status;

#ifdef UCS_HAVE_RCACHE_REGION_ALIGNMENT
status = ucs_rcache_get(rcache, address, length, ucc_get_page_size(),
status = ucs_rcache_get(rcache, address, length, UCS_PGT_ADDR_ALIGN,
PROT_READ | PROT_WRITE, arg, region_p);
#else
status = ucs_rcache_get(rcache, address, length,
Expand Down

0 comments on commit d7cd5a3

Please sign in to comment.