From 622a09e18b356f034197b3bdff898f2b736440b5 Mon Sep 17 00:00:00 2001 From: Mamzi Bayatpour Date: Wed, 15 Nov 2023 11:14:42 -0800 Subject: [PATCH] TL/MLX5: adding mcast helper functions --- .../tl/mlx5/mcast/tl_mlx5_mcast_context.c | 236 +++++++++++- .../tl/mlx5/mcast/tl_mlx5_mcast_helper.c | 364 ++++++++++++++++++ .../tl/mlx5/mcast/tl_mlx5_mcast_team.c | 110 ++++++ 3 files changed, 708 insertions(+), 2 deletions(-) diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c index 90014d1400..7d932a8293 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c @@ -10,9 +10,241 @@ #include #include "core/ucc_service_coll.h" #include "tl_mlx5.h" +#include "tl_mlx5_mcast_helper.h" +#include "tl_mlx5_mcast_rcache.h" -ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *context, /* NOLINT */ - ucc_tl_mlx5_mcast_ctx_params_t *mcast_ctx_conf /* NOLINT */) +#define UCC_TL_MLX5_MCAST_MAX_MTU_COUNT 5 +int mtu_lookup[UCC_TL_MLX5_MCAST_MAX_MTU_COUNT][2] = + {{256, IBV_MTU_256}, {512, IBV_MTU_512}, + {1024, IBV_MTU_1024}, {2048, IBV_MTU_2048}, + {4096, IBV_MTU_4096}}; + +ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *context, + ucc_tl_mlx5_mcast_ctx_params_t *mcast_ctx_conf) { + ucc_status_t status = UCC_OK; + struct ibv_device **device_list = NULL; + struct ibv_device *dev = NULL; + char *devname = NULL; + int is_ipv4 = 0; + struct sockaddr_in *in_src_addr = NULL; + struct rdma_cm_event *revent = NULL; + char *ib = NULL; + char *ib_name = NULL; + char *port = NULL; + int active_mtu = 4096; + int max_mtu = 4096; + ucc_tl_mlx5_mcast_coll_context_t *ctx = NULL; + + struct ibv_port_attr port_attr; + struct ibv_device_attr device_attr; + struct sockaddr_storage ip_oib_addr; + struct sockaddr_storage dst_addr; + int num_devices; + char addrstr[128]; + ucc_tl_mlx5_context_t *mlx5_ctx; + ucc_base_lib_t *lib; + int i; + int user_provided_ib; + int ib_valid; + const char *dst; + + ctx = &(context->mcast_context); + memset(ctx, 0, sizeof(ucc_tl_mlx5_mcast_coll_context_t)); + memcpy(&ctx->params, mcast_ctx_conf, sizeof(ucc_tl_mlx5_mcast_ctx_params_t)); + + mlx5_ctx = ucc_container_of(context, ucc_tl_mlx5_context_t, mcast); + lib = mlx5_ctx->super.super.lib; + ctx->lib = lib; + + device_list = ibv_get_device_list(&num_devices); + if (!device_list || !num_devices) { + tl_debug(lib, "no ib devices available"); + status = UCC_ERR_NOT_SUPPORTED; + goto error; + } + + if (!strcmp(mcast_ctx_conf->ib_dev_name, "")) { + dev = device_list[0]; + devname = (char *)ibv_get_device_name(dev); + ctx->devname = ucc_malloc(strlen(devname)+16, "devname"); + if (!ctx->devname) { + status = UCC_ERR_NO_MEMORY; + goto error; + } + strncpy(ctx->devname, devname, strlen(devname)); + strncat(ctx->devname, ":1", 16); + user_provided_ib = 0; + } else { + ib_valid = 0; + /* user has provided the devname now make sure it is valid */ + for (i = 0; device_list[i]; ++i) { + if (!strcmp(ibv_get_device_name(device_list[i]), mcast_ctx_conf->ib_dev_name)) { + ib_valid = 1; + break; + } + } + if (!ib_valid) { + tl_debug(lib, "ib device %s not found", mcast_ctx_conf->ib_dev_name); + status = UCC_ERR_NOT_FOUND; + goto error; + } + ctx->devname = mcast_ctx_conf->ib_dev_name; + user_provided_ib = 1; + } + + ibv_free_device_list(device_list); + if (UCC_OK != ucc_tl_mlx5_probe_ip_over_ib(ctx->devname, &ip_oib_addr)) { + tl_debug(lib, "failed to get ipoib interface for devname %s", ctx->devname); + status = UCC_ERR_NOT_SUPPORTED; + if (!user_provided_ib) { + ucc_free(ctx->devname); + } + goto error; + } + + is_ipv4 = (ip_oib_addr.ss_family == AF_INET) ? 1 : 0; + in_src_addr = (struct sockaddr_in*)&ip_oib_addr; + + dst = inet_ntop((is_ipv4) ? AF_INET : AF_INET6, + &in_src_addr->sin_addr, addrstr, sizeof(addrstr) - 1); + if (NULL == dst) { + tl_error(lib, "inet_ntop failed"); + status = UCC_ERR_NO_RESOURCE; + goto error; + } + + tl_debug(ctx->lib, "devname %s, ipoib %s", ctx->devname, addrstr); + + ctx->channel = rdma_create_event_channel(); + if (!ctx->channel) { + tl_error(lib, "rdma_create_event_channel failed, errno %d", errno); + status = UCC_ERR_NO_RESOURCE; + goto error; + } + + memset(&dst_addr, 0, sizeof(struct sockaddr_storage)); + dst_addr.ss_family = is_ipv4 ? AF_INET : AF_INET6; + if (rdma_create_id(ctx->channel, &ctx->id, NULL, RDMA_PS_UDP)) { + tl_debug(lib, "failed to create rdma id, errno %d", errno); + status = UCC_ERR_NOT_SUPPORTED; + goto error; + } + + if (0 != rdma_resolve_addr(ctx->id, (struct sockaddr *)&ip_oib_addr, + (struct sockaddr *) &dst_addr, 1000)) { + tl_debug(lib, "failed to resolve rdma addr, errno %d", errno); + status = UCC_ERR_NOT_SUPPORTED; + goto error; + } + + if (rdma_get_cm_event(ctx->channel, &revent) < 0) { + tl_error(lib, "failed to get cm event, errno %d", errno); + status = UCC_ERR_NO_RESOURCE; + goto error; + } else if (revent->event != RDMA_CM_EVENT_ADDR_RESOLVED) { + tl_error(lib, "cm event is not resolved"); + if (rdma_ack_cm_event(revent) < 0) { + tl_error(lib, "rdma_ack_cm_event failed"); + } + status = UCC_ERR_NO_RESOURCE; + goto error; + } + + if (rdma_ack_cm_event(revent) < 0) { + tl_error(lib, "rdma_ack_cm_event failed"); + status = UCC_ERR_NO_RESOURCE; + goto error; + } + + ctx->ctx = ctx->id->verbs; + ctx->pd = ibv_alloc_pd(ctx->ctx); + if (!ctx->pd) { + tl_error(lib, "failed to allocate pd"); + status = UCC_ERR_NO_RESOURCE; + goto error; + } + + ib = strdup(ctx->devname); + ucc_string_split(ib, ":", 2, &ib_name, &port); + ucc_free(ib); + ctx->ib_port = atoi(port); + + /* Determine MTU */ + if (ibv_query_port(ctx->ctx, ctx->ib_port, &port_attr)) { + tl_error(lib, "couldn't query port in ctx create, errno %d", errno); + status = UCC_ERR_NO_RESOURCE; + goto error; + } + + + for (i = 0; i < UCC_TL_MLX5_MCAST_MAX_MTU_COUNT; i++) { + if (mtu_lookup[i][1] == port_attr.max_mtu) { + max_mtu = mtu_lookup[i][0]; + break; + } + } + + for (i = 0; i < UCC_TL_MLX5_MCAST_MAX_MTU_COUNT; i++) { + if (mtu_lookup[i][1] == port_attr.active_mtu) { + active_mtu = mtu_lookup[i][0]; + break; + } + } + + ctx->mtu = active_mtu; + + tl_debug(ctx->lib, "port active MTU is %d and port max MTU is %d", + active_mtu, max_mtu); + + if (port_attr.max_mtu < port_attr.active_mtu) { + tl_debug(ctx->lib, "port active MTU (%d) is smaller than port max MTU (%d)", + active_mtu, max_mtu); + } + + if (ibv_query_device(ctx->ctx, &device_attr)) { + tl_error(lib, "failed to query device in ctx create, errno %d", errno); + status = UCC_ERR_NO_RESOURCE; + goto error; + } + + tl_debug(ctx->lib, "MTU %d, MAX QP WR: %d, max sqr_wr: %d, max cq: %d, max cqe: %d", + ctx->mtu, device_attr.max_qp_wr, device_attr.max_srq_wr, + device_attr.max_cq, device_attr.max_cqe); + + ctx->max_qp_wr = device_attr.max_qp_wr; + status = ucc_mpool_init(&ctx->compl_objects_mp, 0, sizeof(ucc_tl_mlx5_mcast_p2p_completion_obj_t), 0, + UCC_CACHE_LINE_SIZE, 8, UINT_MAX, + &ucc_coll_task_mpool_ops, + UCC_THREAD_SINGLE, + "ucc_tl_mlx5_mcast_p2p_completion_obj_t"); + if (ucc_unlikely(UCC_OK != status)) { + tl_error(lib, "failed to initialize compl_objects_mp mpool"); + status = UCC_ERR_NO_MEMORY; + goto error; + } + + ctx->rcache = NULL; + status = ucc_tl_mlx5_mcast_setup_rcache(ctx); + if (UCC_OK != status) { + tl_error(lib, "failed to setup rcache"); + goto error; + } + + tl_debug(ctx->lib, "context setup complete: ctx %p", ctx); + return UCC_OK; + +error: + if (ctx->pd) { + ibv_dealloc_pd(ctx->pd); + } + if (ctx->id) { + rdma_destroy_id(ctx->id); + } + if (ctx->channel) { + rdma_destroy_event_channel(ctx->channel); + } + + return status; } diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c index 61051a4584..bf1ba9430a 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c @@ -189,3 +189,367 @@ ucc_status_t ucc_tl_mlx5_probe_ip_over_ib(char* ib_dev, struct return status; } +ucc_status_t ucc_tl_mlx5_mcast_join_mcast_post(ucc_tl_mlx5_mcast_coll_context_t *ctx, + struct sockaddr_in6 *net_addr, + int is_root) +{ + char buf[40]; + const char *dst; + + dst = inet_ntop(AF_INET6, net_addr, buf, 40); + if (NULL == dst) { + tl_error(ctx->lib, "inet_ntop failed"); + return UCC_ERR_NO_RESOURCE; + } + + tl_debug(ctx->lib, "joining addr: %s is_root %d", buf, is_root); + + if (rdma_join_multicast(ctx->id, (struct sockaddr*)net_addr, NULL)) { + tl_error(ctx->lib, "rdma_join_multicast failed errno %d", errno); + return UCC_ERR_NO_RESOURCE; + } + + return UCC_OK; +} + +ucc_status_t ucc_tl_mlx5_mcast_join_mcast_test(ucc_tl_mlx5_mcast_coll_context_t *ctx, + struct rdma_cm_event **event, + int is_root) +{ + char buf[40]; + const char *dst; + + if (rdma_get_cm_event(ctx->channel, event) < 0) { + if (EINTR != errno) { + tl_error(ctx->lib, "rdma_get_cm_event failed, errno %d %s", + errno, strerror(errno)); + return UCC_ERR_NO_RESOURCE; + } else { + return UCC_INPROGRESS; + } + } + + if (RDMA_CM_EVENT_MULTICAST_JOIN != (*event)->event) { + tl_error(ctx->lib, "failed to join multicast, is_root %d. unexpected event was" + " received: event=%d, str=%s, status=%d", + is_root, (*event)->event, rdma_event_str((*event)->event), + (*event)->status); + if (rdma_ack_cm_event(*event) < 0) { + tl_error(ctx->lib, "rdma_ack_cm_event failed"); + } + return UCC_ERR_NO_RESOURCE; + } + + dst = inet_ntop(AF_INET6, (*event)->param.ud.ah_attr.grh.dgid.raw, buf, 40); + if (NULL == dst) { + tl_error(ctx->lib, "inet_ntop failed"); + return UCC_ERR_NO_RESOURCE; + } + + tl_debug(ctx->lib, "is_root %d: joined dgid: %s, mlid 0x%x, sl %d", is_root, buf, + (*event)->param.ud.ah_attr.dlid, (*event)->param.ud.ah_attr.sl); + + return UCC_OK; + +} + +ucc_status_t ucc_tl_mlx5_setup_mcast_group_join_post(ucc_tl_mlx5_mcast_coll_comm_t *comm) +{ + ucc_status_t status; + struct sockaddr_in6 net_addr = {0,}; + + if (comm->rank == 0) { + net_addr.sin6_family = AF_INET6; + net_addr.sin6_flowinfo = comm->comm_id; + + status = ucc_tl_mlx5_mcast_join_mcast_post(comm->ctx, &net_addr, 1); + if (status < 0) { + tl_error(comm->lib, "rank 0 is unable to join mcast group"); + return status; + } + } + + return UCC_OK; +} + +ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, + ucc_tl_mlx5_mcast_coll_comm_t *comm) +{ + struct ibv_qp_init_attr qp_init_attr = {0}; + + qp_init_attr.qp_type = IBV_QPT_UD; + qp_init_attr.send_cq = comm->scq; + qp_init_attr.recv_cq = comm->rcq; + qp_init_attr.sq_sig_all = 0; + qp_init_attr.cap.max_send_wr = comm->params.sx_depth; + qp_init_attr.cap.max_recv_wr = comm->params.rx_depth; + qp_init_attr.cap.max_inline_data = comm->params.sx_inline; + qp_init_attr.cap.max_send_sge = comm->params.sx_sge; + qp_init_attr.cap.max_recv_sge = comm->params.rx_sge; + + comm->mcast.qp = ibv_create_qp(ctx->pd, &qp_init_attr); + if (!comm->mcast.qp) { + tl_error(ctx->lib, "failed to create mcast qp, errno %d", errno); + return UCC_ERR_NO_RESOURCE; + } + + comm->max_inline = qp_init_attr.cap.max_inline_data; + + return UCC_OK; +} + +static ucc_status_t ucc_tl_mlx5_mcast_create_ah(ucc_tl_mlx5_mcast_coll_comm_t *comm) +{ + struct ibv_ah_attr ah_attr = { + .is_global = 1, + .grh = {.sgid_index = 0}, + .dlid = comm->mcast_lid, + .sl = DEF_SL, + .src_path_bits = DEF_SRC_PATH_BITS, + .port_num = comm->ctx->ib_port + }; + + memcpy(ah_attr.grh.dgid.raw, &comm->mgid, sizeof(ah_attr.grh.dgid.raw)); + + comm->mcast.ah = ibv_create_ah(comm->ctx->pd, &ah_attr); + if (!comm->mcast.ah) { + tl_error(comm->lib, "failed to create AH"); + return UCC_ERR_NO_RESOURCE; + } + return UCC_OK; +} + +ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, + ucc_tl_mlx5_mcast_coll_comm_t *comm) +{ + struct ibv_port_attr port_attr; + struct ibv_qp_attr attr; + uint16_t pkey; + + ibv_query_port(ctx->ctx, ctx->ib_port, &port_attr); + + for (ctx->pkey_index = 0; ctx->pkey_index < port_attr.pkey_tbl_len; + ++ctx->pkey_index) { + ibv_query_pkey(ctx->ctx, ctx->ib_port, ctx->pkey_index, &pkey); + if (pkey == DEF_PKEY) + break; + } + + if (ctx->pkey_index >= port_attr.pkey_tbl_len) { + ctx->pkey_index = 0; + ibv_query_pkey(ctx->ctx, ctx->ib_port, ctx->pkey_index, &pkey); + if (pkey) { + tl_debug(ctx->lib, "cannot find default pkey 0x%04x on port %d, using index 0 pkey:0x%04x", + DEF_PKEY, ctx->ib_port, pkey); + } else { + tl_error(ctx->lib, "cannot find valid PKEY"); + return UCC_ERR_NO_RESOURCE; + } + } + + attr.qp_state = IBV_QPS_INIT; + attr.pkey_index = ctx->pkey_index; + attr.port_num = ctx->ib_port; + attr.qkey = DEF_QKEY; + + if (ibv_modify_qp(comm->mcast.qp, &attr, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_QKEY)) { + tl_error(ctx->lib, "failed to move mcast qp to INIT, errno %d", errno); + return UCC_ERR_NO_RESOURCE; + } + + if (ibv_attach_mcast(comm->mcast.qp, &comm->mgid, comm->mcast_lid)) { + tl_error(ctx->lib, "failed to attach QP to the mcast group, errno %d", errno); + return UCC_ERR_NO_RESOURCE; + } + + /* Ok, now cycle to RTR on everyone */ + attr.qp_state = IBV_QPS_RTR; + if (ibv_modify_qp(comm->mcast.qp, &attr, IBV_QP_STATE)) { + tl_error(ctx->lib, "failed to modify QP to RTR, errno %d", errno); + return UCC_ERR_NO_RESOURCE; + } + + attr.qp_state = IBV_QPS_RTS; + attr.sq_psn = DEF_PSN; + if (ibv_modify_qp(comm->mcast.qp, &attr, IBV_QP_STATE | IBV_QP_SQ_PSN)) { + tl_error(ctx->lib, "failed to modify QP to RTS, errno %d", errno); + return UCC_ERR_NO_RESOURCE; + } + + /* Create the address handle */ + if (UCC_OK != ucc_tl_mlx5_mcast_create_ah(comm)) { + tl_error(ctx->lib, "failed to create adress handle"); + return UCC_ERR_NO_RESOURCE; + } + + return UCC_OK; +} + +ucc_status_t ucc_tl_mlx5_fini_mcast_group(ucc_tl_mlx5_mcast_coll_context_t *ctx, + ucc_tl_mlx5_mcast_coll_comm_t *comm) +{ + char buf[40]; + const char *dst; + + dst = inet_ntop(AF_INET6, &comm->mcast_addr, buf, 40); + if (NULL == dst) { + tl_error(comm->lib, "inet_ntop failed"); + return UCC_ERR_NO_RESOURCE; + } + + tl_debug(ctx->lib, "mcast leave: ctx %p, comm %p, dgid: %s", ctx, comm, buf); + + if (rdma_leave_multicast(ctx->id, (struct sockaddr*)&comm->mcast_addr)) { + tl_error(comm->lib, "mcast rmda_leave_multicast failed"); + return UCC_ERR_NO_RESOURCE; + } + + return UCC_OK; +} + +ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm) +{ + int ret; + ucc_status_t status; + + tl_debug(comm->lib, "cleaning mcast comm: %p, id %d, mlid %x", + comm, comm->comm_id, comm->mcast_lid); + + if (UCC_OK != (status = ucc_tl_mlx5_mcast_reliable(comm))) { + // TODO handle (UCC_INPROGRESS == ret) + tl_error(comm->lib, "couldn't clean mcast team: relibality progress status %d", + status); + return status; + } + + if (comm->mcast.qp) { + ret = ibv_detach_mcast(comm->mcast.qp, &comm->mgid, comm->mcast_lid); + if (ret) { + tl_error(comm->lib, "couldn't detach QP, ret %d, errno %d", ret, errno); + return UCC_ERR_NO_RESOURCE; + } + } + + if (comm->mcast.qp) { + ret = ibv_destroy_qp(comm->mcast.qp); + if (ret) { + tl_error(comm->lib, "failed to destroy QP %d", ret); + return UCC_ERR_NO_RESOURCE; + } + } + + if (comm->rcq) { + ret = ibv_destroy_cq(comm->rcq); + if (ret) { + tl_error(comm->lib, "couldn't destroy rcq"); + return UCC_ERR_NO_RESOURCE; + } + } + + if (comm->scq) { + ret = ibv_destroy_cq(comm->scq); + if (ret) { + tl_error(comm->lib, "couldn't destroy scq"); + return UCC_ERR_NO_RESOURCE; + } + } + + if (comm->grh_mr) { + ret = ibv_dereg_mr(comm->grh_mr); + if (ret) { + tl_error(comm->lib, "couldn't destroy grh mr"); + return UCC_ERR_NO_RESOURCE; + } + } + if (comm->grh_buf) { + ucc_free(comm->grh_buf); + } + + if (comm->pp) { + ucc_free(comm->pp); + } + + if (comm->pp_mr) { + ret = ibv_dereg_mr(comm->pp_mr); + if (ret) { + tl_error(comm->lib, "couldn't destroy pp mr"); + return UCC_ERR_NO_RESOURCE; + } + } + + if (comm->pp_buf) { + ucc_free(comm->pp_buf); + } + + if (comm->call_rwr) { + ucc_free(comm->call_rwr); + } + + if (comm->call_rsgs) { + ucc_free(comm->call_rsgs); + } + + if (comm->mcast.ah) { + ret = ibv_destroy_ah(comm->mcast.ah); + if (ret) { + tl_error(comm->lib, "couldn't destroy ah"); + return UCC_ERR_NO_RESOURCE; + } + } + + if (comm->mcast_lid) { + status = ucc_tl_mlx5_fini_mcast_group(comm->ctx, comm); + if (status) { + tl_error(comm->lib, "couldn't leave mcast group"); + return status; + } + } + + if (comm->ctx->params.print_nack_stats) { + tl_debug(comm->lib, "comm_id %d, comm_size %d, comm->psn %d, rank %d, " + "nacks counter %d, n_mcast_rel %d", + comm->comm_id, comm->commsize, comm->psn, comm->rank, + comm->nacks_counter, comm->n_mcast_reliable); + } + + if (comm->p2p_ctx != NULL) { + ucc_free(comm->p2p_ctx); + } + + ucc_free(comm); + + return UCC_OK; +} + +ucc_status_t ucc_tl_mlx5_clean_mcast_ctx(ucc_tl_mlx5_mcast_coll_context_t *ctx) +{ + tl_debug(ctx->lib, "cleaning mcast ctx: %p", ctx); + + if (ctx->rcache) { + ucc_rcache_destroy(ctx->rcache); + } + + if (ctx->pd) { + if (ibv_dealloc_pd(ctx->pd)) { + tl_error(ctx->lib, "ibv_dealloc_pd failed errno %d", errno); + return UCC_ERR_NO_RESOURCE; + } + } + + if (rdma_destroy_id(ctx->id)) { + tl_error(ctx->lib, "rdma_destroy_id failed errno %d", errno); + return UCC_ERR_NO_RESOURCE; + } + + rdma_destroy_event_channel(ctx->channel); + + if (!strcmp(ctx->params.ib_dev_name, "")) { + ucc_free(ctx->devname); + } + + ucc_free(ctx); + + return UCC_OK; +} + diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c index 31044fe8b3..f56bc3c1a1 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c @@ -8,6 +8,7 @@ #include "tl_mlx5.h" #include "tl_mlx5_mcast_coll.h" #include "coll_score/ucc_coll_score.h" +#include "tl_mlx5_mcast_helper.h" ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, /* NOLINT */ ucc_tl_mlx5_mcast_team_t **mcast_team, /* NOLINT */ @@ -18,3 +19,112 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_cont return UCC_OK; } +ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_comm_t *comm) +{ + ucc_status_t status; + size_t page_size; + int buf_size, i, ret; + + status = ucc_tl_mlx5_mcast_init_qps(comm->ctx, comm); + if (UCC_OK != status) { + goto error; + } + + status = ucc_tl_mlx5_mcast_setup_qps(comm->ctx, comm); + if (UCC_OK != status) { + goto error; + } + + page_size = ucc_get_page_size(); + buf_size = comm->ctx->mtu; + + // Comm receiving buffers. + ret = posix_memalign((void**)&comm->call_rwr, page_size, sizeof(struct ibv_recv_wr) * + comm->params.rx_depth); + if (ret) { + tl_error(comm->ctx->lib, "posix_memalign failed"); + return UCC_ERR_NO_MEMORY; + } + + ret = posix_memalign((void**)&comm->call_rsgs, page_size, sizeof(struct ibv_sge) * + comm->params.rx_depth * 2); + if (ret) { + tl_error(comm->ctx->lib, "posix_memalign failed"); + return UCC_ERR_NO_MEMORY; + } + + comm->pending_recv = 0; + comm->buf_n = comm->params.rx_depth * 2; + + ret = posix_memalign((void**) &comm->pp_buf, page_size, buf_size * comm->buf_n); + if (ret) { + tl_error(comm->ctx->lib, "posix_memalign failed"); + return UCC_ERR_NO_MEMORY; + } + + memset(comm->pp_buf, 0, buf_size * comm->buf_n); + + comm->pp_mr = ibv_reg_mr(comm->ctx->pd, comm->pp_buf, buf_size * comm->buf_n, + IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_LOCAL_WRITE); + if (!comm->pp_mr) { + tl_error(comm->ctx->lib, "could not register pp_buf mr, errno %d", errno); + status = UCC_ERR_NO_MEMORY; + goto error; + } + + ret = posix_memalign((void**) &comm->pp, page_size, sizeof(struct + pp_packet) * comm->buf_n); + if (ret) { + tl_error(comm->ctx->lib, "posix_memalign failed"); + return UCC_ERR_NO_MEMORY; + } + + for (i = 0; i < comm->buf_n; i++) { + ucc_list_head_init(&comm->pp[i].super); + + comm->pp[i].buf = (uintptr_t) comm->pp_buf + i * buf_size; + comm->pp[i].context = 0; + + ucc_list_add_tail(&comm->bpool, &comm->pp[i].super); + } + + comm->mcast.swr.wr.ud.ah = comm->mcast.ah; + comm->mcast.swr.num_sge = 1; + comm->mcast.swr.sg_list = &comm->mcast.ssg; + comm->mcast.swr.opcode = IBV_WR_SEND_WITH_IMM; + comm->mcast.swr.wr.ud.remote_qpn = MULTICAST_QPN; + comm->mcast.swr.wr.ud.remote_qkey = DEF_QKEY; + comm->mcast.swr.next = NULL; + + for (i = 0; i < comm->params.rx_depth; i++) { + comm->call_rwr[i].sg_list = &comm->call_rsgs[2 * i]; + comm->call_rwr[i].num_sge = 2; + comm->call_rwr[i].wr_id = MCAST_BCASTRECV_WR; + comm->call_rsgs[2 * i].length = GRH_LENGTH; + comm->call_rsgs[2 * i].addr = (uintptr_t)comm->grh_buf; + comm->call_rsgs[2 * i].lkey = comm->grh_mr->lkey; + comm->call_rsgs[2 * i + 1].lkey = comm->pp_mr->lkey; + comm->call_rsgs[2 * i + 1].length = comm->max_per_packet; + } + + status = ucc_tl_mlx5_mcast_post_recv_buffers(comm); + if (UCC_OK != status) { + goto error; + } + + memset(comm->parents, 0, sizeof(comm->parents)); + memset(comm->children, 0, sizeof(comm->children)); + + comm->nacks_counter = 0; + comm->tx = 0; + comm->n_prep_reliable = 0; + comm->n_mcast_reliable = 0; + comm->reliable_in_progress = 0; + comm->recv_drop_packet_in_progress = 0; + + return status; + +error: + ucc_tl_mlx5_clean_mcast_comm(comm); + return status; +}