Skip to content

Commit

Permalink
add group info struct
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB committed Feb 4, 2024
1 parent 7fe2eb1 commit 473f4f7
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 55 deletions.
8 changes: 7 additions & 1 deletion src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,12 @@ typedef struct ucc_tl_mlx5_mcast_coll_context {
ucc_base_lib_t *lib;
} ucc_tl_mlx5_mcast_coll_context_t;

/*
 * Multicast group setup information exchanged during team creation.
 * Rank 0 fills this struct in and posts it through comm->bcast_post() as a
 * single buffer of sizeof(ucc_tl_mlx5_mcast_join_info_t) bytes; the buffer is
 * kept in comm->group_setup_info, from which the other ranks read the status
 * and the group address.
 * NOTE(review): the struct is sent as raw bytes, so its layout presumably has
 * to be identical on every rank - confirm before reordering or adding fields.
 */
typedef struct ucc_tl_mlx5_mcast_join_info_t {
/* outcome of rank 0's group join: rank 0 stores its local status on the
 * JOIN_READY path and UCC_ERR_NO_RESOURCE on the JOIN_FAILED path; readers
 * treat any value other than UCC_OK as "rank 0 could not join" */
ucc_status_t status;
/* multicast LID, taken from comm->event->param.ud.ah_attr.dlid on rank 0 and
 * stored into comm->mcast_lid */
uint16_t dlid;
/* multicast GID, taken from comm->event->param.ud.ah_attr.grh.dgid on rank 0
 * and stored into comm->mgid; the other ranks also copy it into
 * net_addr.sin6_addr before posting their own multicast join */
union ibv_gid dgid;
} ucc_tl_mlx5_mcast_join_info_t;

typedef struct ucc_tl_mlx5_mcast_context {
ucc_thread_mode_t tm;
ucc_tl_mlx5_mcast_coll_context_t mcast_context;
Expand Down Expand Up @@ -228,7 +234,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
int n_prep_reliable;
int n_mcast_reliable;
int wsize;
void *group_setup_info;
ucc_tl_mlx5_mcast_join_info_t *group_setup_info;
ucc_service_coll_req_t *group_setup_info_req;
ucc_status_t (*bcast_post) (void*, void*, size_t, ucc_rank_t, ucc_service_coll_req_t**);
ucc_status_t (*bcast_test) (ucc_service_coll_req_t*);
Expand Down
79 changes: 25 additions & 54 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team)
ucc_tl_mlx5_mcast_coll_comm_t *comm = tl_team->mcast->mcast_comm;
ucc_status_t status = UCC_OK;
struct sockaddr_in6 net_addr = {0,};
void *data = NULL;
ucc_tl_mlx5_mcast_join_info_t *data = NULL;
struct rdma_cm_event tmp;
size_t mgid_s;
size_t mlid_s;
Expand Down Expand Up @@ -357,60 +357,33 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team)
}

case TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_READY:
case TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_FAILED:
{
/* now rank 0 bcast the lid/gid to other processes */
data = ucc_malloc(sizeof(ucc_status_t) + mgid_s + mlid_s);

data = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_join_info_t),
"ucc_tl_mlx5_mcast_join_info_t");
if (!data) {
tl_error(comm->lib, "unable to allocate memory for group setup info");
return UCC_ERR_NO_MEMORY;
}

comm->group_setup_info = data;
memset(data, 0, sizeof(ucc_status_t) + mgid_s + mlid_s);
memcpy((PTR_OFFSET(data, sizeof(ucc_status_t))),
&comm->event->param.ud.ah_attr.grh.dgid, mgid_s);
memcpy((PTR_OFFSET(data, sizeof(ucc_status_t) + mgid_s)),
&comm->event->param.ud.ah_attr.dlid, mlid_s);
memcpy(data, &status, sizeof(ucc_status_t));

status = comm->bcast_post(comm->p2p_ctx, data, sizeof(ucc_status_t) + mgid_s +
mlid_s, 0, &comm->group_setup_info_req);
if (UCC_OK != status) {
tl_error(comm->lib, "unable to post bcast for group setup info");
ucc_free(comm->group_setup_info);
if (comm->event) {
rdma_ack_cm_event(comm->event);
comm->event = NULL;
}
return status;
}

comm->mcast_lid = *((uint16_t*)(PTR_OFFSET(data, sizeof(ucc_status_t) +
mgid_s)));
comm->mcast_addr = net_addr;

memcpy((void*)&comm->mgid, PTR_OFFSET(data, sizeof(ucc_status_t)),
mgid_s);

tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_BCAST_POST;

return UCC_INPROGRESS;
}
case TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_FAILED:
{
/* now rank 0 bcast the failed status to other processes so others do not hang */
data = ucc_malloc(sizeof(ucc_status_t) + mgid_s + mlid_s);
if (!data) {
tl_error(comm->lib, "unable to allocate memory for group setup info");
return UCC_ERR_NO_MEMORY;
if (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_READY) {
/* rank 0 bcast the lid/gid to other processes */
data->status = status;
data->dgid = comm->event->param.ud.ah_attr.grh.dgid;
data->dlid = comm->event->param.ud.ah_attr.dlid;
comm->mcast_lid = data->dlid;
comm->mgid = data->dgid;
comm->mcast_addr = net_addr;
} else {
/* rank 0 bcast the failed status to other processes so others do not hang */
data->status = UCC_ERR_NO_RESOURCE;
}

comm->group_setup_info = data;
status = UCC_ERR_NO_RESOURCE;
memset(data, 0, sizeof(ucc_status_t) + mgid_s + mlid_s);
memcpy(data, &status, sizeof(ucc_status_t));
status = comm->bcast_post(comm->p2p_ctx, data, sizeof(ucc_status_t) + mgid_s +
mlid_s, 0, &comm->group_setup_info_req);
status = comm->bcast_post(comm->p2p_ctx, data, sizeof(ucc_tl_mlx5_mcast_join_info_t),
0, &comm->group_setup_info_req);
if (UCC_OK != status) {
tl_error(comm->lib, "unable to post bcast for group setup info");
ucc_free(comm->group_setup_info);
Expand All @@ -435,7 +408,7 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team)
return status;
}

if (*((ucc_status_t *)(comm->group_setup_info)) != UCC_OK) {
if (comm->group_setup_info->status != UCC_OK) {
/* rank 0 was not able to join a mcast group so all
* the ranks should return */
ucc_free(comm->group_setup_info);
Expand Down Expand Up @@ -479,7 +452,8 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team)
{
/* non-zero ranks post a bcast and wait for rank 0 to send the lid/gid
* of the mcast group */
data = ucc_malloc(sizeof(ucc_status_t) + mgid_s + mlid_s);
data = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_join_info_t),
"ucc_tl_mlx5_mcast_join_info_t");
if (!data) {
tl_error(comm->lib, "unable to allocate memory for group setup info");
return UCC_ERR_NO_MEMORY;
Expand Down Expand Up @@ -509,7 +483,7 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team)
}

data = comm->group_setup_info;
status = *(ucc_status_t *) data;
status = data->status;
if (UCC_OK != status) {
/* rank 0 was not able to join a mcast group so all
* the ranks should return */
Expand All @@ -518,7 +492,7 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team)
}

/* now it is time for the non-zero ranks to call rdma_join_multicast() */
memcpy(&net_addr.sin6_addr, PTR_OFFSET(data, sizeof(ucc_status_t)), sizeof(struct in6_addr));
memcpy(&net_addr.sin6_addr, &(data->dgid), sizeof(struct in6_addr));
net_addr.sin6_family = AF_INET6;

status = ucc_tl_mlx5_mcast_join_mcast_post(comm->ctx, &net_addr, 0);
Expand Down Expand Up @@ -548,13 +522,10 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team)

ucc_assert(comm->event != NULL);

comm->mcast_lid = *((uint16_t*)(PTR_OFFSET(comm->group_setup_info,
sizeof(ucc_status_t) + mgid_s)));
comm->mcast_lid = comm->group_setup_info->dlid;
comm->mgid = comm->group_setup_info->dgid;
comm->mcast_addr = net_addr;

memcpy((void*)&comm->mgid, PTR_OFFSET(comm->group_setup_info,
sizeof(ucc_status_t)), mgid_s);

ucc_free(comm->group_setup_info);
if (comm->event) {
rdma_ack_cm_event(comm->event);
Expand Down

0 comments on commit 473f4f7

Please sign in to comment.