Skip to content

Commit

Permalink
TL/MLX5: mcast multi-group support part 1
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB committed Dec 23, 2024
1 parent 73651ea commit 79512f4
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 83 deletions.
10 changes: 4 additions & 6 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,6 @@ struct pp_packet {
};

struct mcast_ctx {
struct ibv_qp *qp;
struct ibv_ah *ah;
struct ibv_send_wr swr;
struct ibv_sge ssg;

Expand Down Expand Up @@ -310,8 +308,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
ucc_rank_t commsize;
char *grh_buf;
struct ibv_mr *grh_mr;
uint16_t mcast_lid;
union ibv_gid mgid;
uint16_t *lid_list;
union ibv_gid *mgid_list;
unsigned max_inline;
size_t max_eager;
int max_per_packet;
Expand All @@ -334,7 +332,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
int comm_id;
void *p2p_ctx;
ucc_base_lib_t *lib;
struct sockaddr_in6 mcast_addr;
struct sockaddr_in6 *mcast_addr_list;
int cuda_mem_enabled;
ucc_tl_mlx5_mcast_join_info_t *group_setup_info;
ucc_service_coll_req_t *group_setup_info_req;
Expand Down Expand Up @@ -490,7 +488,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_recv_buffers(ucc_tl_mlx5_mcast
}
if (i != 0) {
rwr[i-1].next = NULL;
if (ibv_post_recv(comm->mcast.qp, &rwr[0], &bad_wr)) {
if (ibv_post_recv(comm->mcast.qp_list[0], &rwr[0], &bad_wr)) {
tl_error(comm->lib, "failed to prepost recvs: errno %d", errno);
return UCC_ERR_NO_RESOURCE;
}
Expand Down
201 changes: 137 additions & 64 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,19 @@ ucc_status_t ucc_tl_mlx5_setup_mcast_group_join_post(ucc_tl_mlx5_mcast_coll_comm
ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx,
ucc_tl_mlx5_mcast_coll_comm_t *comm)
{
struct ibv_qp_init_attr qp_init_attr = {0};
int max_inline = INT_MAX;
struct ibv_qp_init_attr qp_init_attr = {0};
int i;
int j;

comm->mcast.qp_list = ucc_malloc(comm->mcast_group_count * sizeof(struct ibv_qp *), "ibv_qp* list");
if (!comm->mcast.qp_list) {
tl_error(ctx->lib, "failed to allocate memory for ibv_qp*");
return UCC_ERR_NO_MEMORY;
}

qp_init_attr.qp_type = IBV_QPT_UD;
qp_init_attr.send_cq = comm->scq;
qp_init_attr.send_cq = comm->scq; //cq can be shared between multiple QPs
qp_init_attr.recv_cq = comm->rcq;
qp_init_attr.sq_sig_all = 0;
qp_init_attr.cap.max_send_wr = comm->params.sx_depth;
Expand All @@ -294,41 +303,78 @@ ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx,
qp_init_attr.cap.max_send_sge = comm->params.sx_sge;
qp_init_attr.cap.max_recv_sge = comm->params.rx_sge;

comm->mcast.qp = ibv_create_qp(ctx->pd, &qp_init_attr);
if (!comm->mcast.qp) {
tl_warn(ctx->lib, "failed to create mcast qp, errno %d", errno);
return UCC_ERR_NO_RESOURCE;
for (i = 0; i < comm->mcast_group_count; i++) {
comm->mcast.qp_list[i] = ibv_create_qp(ctx->pd, &qp_init_attr);
if (!comm->mcast.qp_list[i]) {
tl_error(ctx->lib, "Failed to create mcast UD qp index %d, errno %d", i, errno);
goto error;
}
if (qp_init_attr.cap.max_inline_data < max_inline) {
max_inline = qp_init_attr.cap.max_inline_data;
}
}

if (comm->cuda_mem_enabled) {
/* max inline send otherwise it segfault during ibv send */
comm->max_inline = 0;
} else {
comm->max_inline = qp_init_attr.cap.max_inline_data;
comm->max_inline = max_inline;
}

return UCC_OK;

error:
for (j = 0; j < i; j++) {
ibv_destroy_qp(comm->mcast.qp_list[j]);
}
ucc_free(comm->mcast.qp_list);
comm->mcast.qp_list = NULL;

return UCC_ERR_NO_RESOURCE;
}

static ucc_status_t ucc_tl_mlx5_mcast_create_ah(ucc_tl_mlx5_mcast_coll_comm_t *comm)
{
int i, j, ret;
struct ibv_ah_attr ah_attr = {
.is_global = 1,
.grh = {.sgid_index = 0},
.dlid = comm->mcast_lid,
.sl = DEF_SL,
.src_path_bits = DEF_SRC_PATH_BITS,
.port_num = comm->ctx->ib_port
};

memcpy(ah_attr.grh.dgid.raw, &comm->mgid, sizeof(ah_attr.grh.dgid.raw));
comm->mcast.ah_list = ucc_malloc(comm->mcast_group_count * sizeof(struct ibv_ah *), "ibv_ah array");
if (!comm->mcast.ah_list) {
tl_error(comm->lib, "failed to allocate memory for mcast address handle of size %lu",
comm->mcast_group_count * sizeof(struct ibv_ah *));
return UCC_ERR_NO_MEMORY;
}

comm->mcast.ah = ibv_create_ah(comm->ctx->pd, &ah_attr);
if (!comm->mcast.ah) {
tl_warn(comm->lib, "failed to create AH");
return UCC_ERR_NO_RESOURCE;
for (i = 0; i < comm->mcast_group_count; i ++) {
ah_attr.dlid = comm->lid_list[i];
memcpy(ah_attr.grh.dgid.raw, &comm->mgid_list[i], sizeof(ah_attr.grh.dgid.raw));

comm->mcast.ah_list[i] = ibv_create_ah(comm->ctx->pd, &ah_attr);
if (!comm->mcast.ah_list[i]) {
tl_error(comm->lib, "failed to create AH index %d", i);
goto error;
}
}

return UCC_OK;

error:
for (j = 0; j < i; j++) {
ret = ibv_destroy_ah(comm->mcast.ah_list[j]);
if (ret) {
tl_error(comm->lib, "couldn't destroy ah");
return UCC_ERR_NO_RESOURCE;
}
}
ucc_free(comm->mcast.ah_list);
comm->mcast.ah_list = NULL;
return UCC_ERR_NO_RESOURCE;
}

ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx,
Expand All @@ -337,16 +383,15 @@ ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx,
struct ibv_port_attr port_attr;
struct ibv_qp_attr attr;
uint16_t pkey;
int i;

ibv_query_port(ctx->ctx, ctx->ib_port, &port_attr);

for (ctx->pkey_index = 0; ctx->pkey_index < port_attr.pkey_tbl_len;
++ctx->pkey_index) {
ibv_query_pkey(ctx->ctx, ctx->ib_port, ctx->pkey_index, &pkey);
if (pkey == DEF_PKEY)
break;
}

if (ctx->pkey_index >= port_attr.pkey_tbl_len) {
ctx->pkey_index = 0;
ibv_query_pkey(ctx->ctx, ctx->ib_port, ctx->pkey_index, &pkey);
Expand All @@ -359,43 +404,53 @@ ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx,
"index 0 pkey:0x%04x", DEF_PKEY, ctx->ib_port, pkey);
}

attr.qp_state = IBV_QPS_INIT;
attr.pkey_index = ctx->pkey_index;
attr.port_num = ctx->ib_port;
attr.qkey = DEF_QKEY;
for (i = 0; i < comm->mcast_group_count; i++) {
attr.qp_state = IBV_QPS_INIT;
attr.pkey_index = ctx->pkey_index;
attr.port_num = ctx->ib_port;
attr.qkey = DEF_QKEY;

if (ibv_modify_qp(comm->mcast.qp, &attr,
IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_QKEY)) {
tl_warn(ctx->lib, "failed to move mcast qp to INIT, errno %d", errno);
return UCC_ERR_NO_RESOURCE;
}
if (ibv_modify_qp(comm->mcast.qp_list[i], &attr,
IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_QKEY)) {
tl_error(ctx->lib, "failed to move mcast qp to INIT, errno %d", errno);
goto error;
}

if (ibv_attach_mcast(comm->mcast.qp, &comm->mgid, comm->mcast_lid)) {
tl_warn(ctx->lib, "failed to attach QP to the mcast group, errno %d", errno);
return UCC_ERR_NO_RESOURCE;
}
if (ibv_attach_mcast(comm->mcast.qp_list[i], &comm->mgid_list[i], comm->lid_list[i])) {
tl_error(ctx->lib, "failed to attach QP to the mcast group with mcast_lid %d , errno %d", errno, comm->lid_list[i]);
goto error;
}

/* Ok, now cycle to RTR on everyone */
attr.qp_state = IBV_QPS_RTR;
if (ibv_modify_qp(comm->mcast.qp, &attr, IBV_QP_STATE)) {
tl_warn(ctx->lib, "failed to modify QP to RTR, errno %d", errno);
return UCC_ERR_NO_RESOURCE;
}
attr.qp_state = IBV_QPS_RTR;
if (ibv_modify_qp(comm->mcast.qp_list[i], &attr, IBV_QP_STATE)) {
tl_error(ctx->lib, "failed to modify QP to RTR, errno %d", errno);
goto error;
}

attr.qp_state = IBV_QPS_RTS;
attr.sq_psn = DEF_PSN;
if (ibv_modify_qp(comm->mcast.qp, &attr, IBV_QP_STATE | IBV_QP_SQ_PSN)) {
tl_warn(ctx->lib, "failed to modify QP to RTS, errno %d", errno);
return UCC_ERR_NO_RESOURCE;
attr.qp_state = IBV_QPS_RTS;
attr.sq_psn = DEF_PSN;
if (ibv_modify_qp(comm->mcast.qp_list[i], &attr, IBV_QP_STATE | IBV_QP_SQ_PSN)) {
tl_error(ctx->lib, "failed to modify QP to RTS, errno %d", errno);
goto error;
}
}

/* Create the address handle */
/* create the address handle */
if (UCC_OK != ucc_tl_mlx5_mcast_create_ah(comm)) {
tl_warn(ctx->lib, "failed to create adress handle");
return UCC_ERR_NO_RESOURCE;
goto error;
}

return UCC_OK;

error:
for (i=0; i < comm->mcast_group_count; i++) {
ibv_destroy_qp(comm->mcast.qp_list[i]);
}
ucc_free(comm->mcast.qp_list);
comm->mcast.qp_list = NULL;

return UCC_ERR_NO_RESOURCE;
}

ucc_status_t ucc_tl_mlx5_mcast_create_rc_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx,
Expand Down Expand Up @@ -538,15 +593,15 @@ ucc_status_t ucc_tl_mlx5_fini_mcast_group(ucc_tl_mlx5_mcast_coll_context_t *ctx,
char buf[40];
const char *dst;

dst = inet_ntop(AF_INET6, &comm->mcast_addr, buf, 40);
dst = inet_ntop(AF_INET6, &comm->mcast_addr_list[0], buf, 40);
if (NULL == dst) {
tl_error(comm->lib, "inet_ntop failed");
return UCC_ERR_NO_RESOURCE;
}

tl_debug(ctx->lib, "mcast leave: ctx %p, comm %p, dgid: %s", ctx, comm, buf);

if (rdma_leave_multicast(ctx->id, (struct sockaddr*)&comm->mcast_addr)) {
if (rdma_leave_multicast(ctx->id, (struct sockaddr*)&comm->mcast_addr_list[0])) {
tl_error(comm->lib, "mcast rmda_leave_multicast failed");
return UCC_ERR_NO_RESOURCE;
}
Expand All @@ -559,11 +614,10 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
ucc_tl_mlx5_mcast_context_t *mcast_ctx = ucc_container_of(comm->ctx, ucc_tl_mlx5_mcast_context_t, mcast_context);
ucc_tl_mlx5_context_t *mlx5_ctx = ucc_container_of(mcast_ctx, ucc_tl_mlx5_context_t, mcast);
ucc_context_h context = mlx5_ctx->super.super.ucc_context;
int ret;
int ret, i;
ucc_status_t status;

tl_debug(comm->lib, "cleaning mcast comm: %p, id %d, mlid %x",
comm, comm->comm_id, comm->mcast_lid);
tl_debug(comm->lib, "cleaning mcast comm: %p, id %d", comm, comm->comm_id);

while (UCC_INPROGRESS == (status = ucc_tl_mlx5_mcast_reliable(comm))) {
ucc_context_progress(context);
Expand All @@ -575,20 +629,26 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
return status;
}

if (comm->mcast.qp) {
ret = ibv_detach_mcast(comm->mcast.qp, &comm->mgid, comm->mcast_lid);
if (ret) {
tl_error(comm->lib, "couldn't detach QP, ret %d, errno %d", ret, errno);
return UCC_ERR_NO_RESOURCE;
}
}
if (comm->mcast.qp_list) {
for (i = 0; i < comm->mcast_group_count; i++) {
if (comm->mcast.qp_list[i]) {
ret = ibv_detach_mcast(comm->mcast.qp_list[i], &(comm->mgid_list[i]), comm->lid_list[i]);
if (ret) {
tl_error(comm->lib, "couldn't detach QP, ret %d, errno %d", ret, errno);
return UCC_ERR_NO_RESOURCE;
}

if (comm->mcast.qp) {
ret = ibv_destroy_qp(comm->mcast.qp);
if (ret) {
tl_error(comm->lib, "failed to destroy QP %d", ret);
return UCC_ERR_NO_RESOURCE;
ret = ibv_destroy_qp(comm->mcast.qp_list[i]);
if (ret) {
tl_error(comm->lib, "failed to destroy QP %d", ret);
return UCC_ERR_NO_RESOURCE;
}

comm->mcast.qp_list[i] = NULL;
}
}
ucc_free(comm->mcast.qp_list);
comm->mcast.qp_list = NULL;
}

if (comm->rcq) {
Expand Down Expand Up @@ -643,20 +703,33 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
ucc_free(comm->call_rsgs);
}

if (comm->mcast.ah) {
ret = ibv_destroy_ah(comm->mcast.ah);
if (ret) {
tl_error(comm->lib, "couldn't destroy ah");
return UCC_ERR_NO_RESOURCE;
if (comm->mcast.ah_list) {
for (i = 0; i < comm->mcast_group_count; i++) {
if (comm->mcast.ah_list[i]) {
ret = ibv_destroy_ah(comm->mcast.ah_list[i]);
if (ret) {
tl_error(comm->lib, "couldn't destroy ah");
return UCC_ERR_NO_RESOURCE;
}
comm->mcast.ah_list[i] = NULL;
}
}
ucc_free(comm->mcast.ah_list);
comm->mcast.ah_list = NULL;
}

if (comm->mcast_lid) {
if (comm->lid_list) {
status = ucc_tl_mlx5_fini_mcast_group(comm->ctx, comm);
if (status) {
tl_error(comm->lib, "couldn't leave mcast group");
return status;
}
ucc_free(comm->lid_list);
ucc_free(comm->mgid_list);
ucc_free(comm->mcast_addr_list);
comm->lid_list = NULL;
comm->lid_list = NULL;
comm->mcast_addr_list = NULL;
}

if (comm->ctx->params.print_nack_stats) {
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t
tl_trace(comm->lib, "post_send, psn %d, length %d, zcopy %d, signaled %d",
pp->psn, pp->length, zcopy, swr[0].send_flags & IBV_SEND_SIGNALED);

if (0 != (rc = ibv_post_send(comm->mcast.qp, &swr[0], &bad_wr))) {
if (0 != (rc = ibv_post_send(comm->mcast.qp_list[0], &swr[0], &bad_wr))) {
tl_error(comm->lib, "post send failed: ret %d, start_psn %d, to_send %d, "
"to_recv %d, length %d, psn %d, inline %d",
rc, req->start_psn, req->to_send, req->to_recv,
Expand Down
Loading

0 comments on commit 79512f4

Please sign in to comment.