From 7c333e1f539fa354ed2390b21a9823cad3e00e51 Mon Sep 17 00:00:00 2001 From: Mamzi Bayatpour <77160721+MamziB@users.noreply.github.com> Date: Mon, 24 Jun 2024 03:41:59 -0700 Subject: [PATCH] TL/MLX5: one-sided mcast reliability init (#980) --- src/components/tl/mlx5/Makefile.am | 30 ++- .../tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c | 78 ++++++ .../tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h | 9 + src/components/tl/mlx5/mcast/tl_mlx5_mcast.h | 174 ++++++++----- .../tl/mlx5/mcast/tl_mlx5_mcast_context.c | 3 +- .../tl/mlx5/mcast/tl_mlx5_mcast_helper.c | 134 ++++++++++ .../tl/mlx5/mcast/tl_mlx5_mcast_helper.h | 6 + .../tl_mlx5_mcast_one_sided_reliability.c | 228 ++++++++++++++++++ .../tl_mlx5_mcast_one_sided_reliability.h | 19 ++ .../mlx5/mcast/tl_mlx5_mcast_service_coll.c | 83 +++++++ .../mlx5/mcast/tl_mlx5_mcast_service_coll.h | 18 ++ .../tl/mlx5/mcast/tl_mlx5_mcast_team.c | 53 +--- 12 files changed, 722 insertions(+), 113 deletions(-) create mode 100644 src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c create mode 100644 src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.h create mode 100644 src/components/tl/mlx5/mcast/tl_mlx5_mcast_service_coll.c create mode 100644 src/components/tl/mlx5/mcast/tl_mlx5_mcast_service_coll.h diff --git a/src/components/tl/mlx5/Makefile.am b/src/components/tl/mlx5/Makefile.am index 01b94cfa1d..a7bc249f87 100644 --- a/src/components/tl/mlx5/Makefile.am +++ b/src/components/tl/mlx5/Makefile.am @@ -12,19 +12,23 @@ alltoall = \ alltoall/alltoall_inline.h \ alltoall/alltoall_coll.c -mcast = \ - mcast/tl_mlx5_mcast_context.c \ - mcast/tl_mlx5_mcast.h \ - mcast/tl_mlx5_mcast_coll.c \ - mcast/tl_mlx5_mcast_coll.h \ - mcast/tl_mlx5_mcast_rcache.h \ - mcast/tl_mlx5_mcast_rcache.c \ - mcast/p2p/ucc_tl_mlx5_mcast_p2p.h \ - mcast/p2p/ucc_tl_mlx5_mcast_p2p.c \ - mcast/tl_mlx5_mcast_progress.h \ - mcast/tl_mlx5_mcast_progress.c \ - mcast/tl_mlx5_mcast_helper.h \ - mcast/tl_mlx5_mcast_helper.c \ +mcast = \ + mcast/tl_mlx5_mcast_context.c \ + mcast/tl_mlx5_mcast.h \ + mcast/tl_mlx5_mcast_coll.c \ + mcast/tl_mlx5_mcast_coll.h \ + mcast/tl_mlx5_mcast_rcache.h \ + mcast/tl_mlx5_mcast_rcache.c \ + mcast/p2p/ucc_tl_mlx5_mcast_p2p.h \ + mcast/p2p/ucc_tl_mlx5_mcast_p2p.c \ + mcast/tl_mlx5_mcast_progress.h \ + mcast/tl_mlx5_mcast_progress.c \ + mcast/tl_mlx5_mcast_helper.h \ + mcast/tl_mlx5_mcast_helper.c \ + mcast/tl_mlx5_mcast_service_coll.h \ + mcast/tl_mlx5_mcast_service_coll.c \ + mcast/tl_mlx5_mcast_one_sided_reliability.h \ + mcast/tl_mlx5_mcast_one_sided_reliability.c \ mcast/tl_mlx5_mcast_team.c sources = \ diff --git a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c index 11be1473d2..ea57bfa89c 100644 --- a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c +++ b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c @@ -139,3 +139,81 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void *dst, size_t size, ucc_rank_t return status; } + +ucc_status_t ucc_tl_mlx5_one_sided_p2p_put(void* src, void* remote_addr, size_t length, + uint32_t lkey, uint32_t rkey, ucc_rank_t target_rank, + uint64_t wr_id, ucc_tl_mlx5_mcast_coll_comm_t *comm) +{ + struct ibv_send_wr swr = {0}; + struct ibv_sge ssg = {0}; + struct ibv_send_wr *bad_wr; + int rc; + + if (UINT32_MAX < length) { + tl_error(comm->lib, "msg too large for p2p put"); + return UCC_ERR_NOT_SUPPORTED; + } + + ssg.addr = (uint64_t)src; + ssg.length = (uint32_t)length; + ssg.lkey = lkey; + swr.sg_list = &ssg; + swr.num_sge = 1; + + swr.opcode = IBV_WR_RDMA_WRITE; + swr.wr_id = wr_id; + swr.send_flags = IBV_SEND_SIGNALED; + swr.wr.rdma.remote_addr = (uint64_t)remote_addr; + swr.wr.rdma.rkey = rkey; + swr.next = NULL; + + tl_trace(comm->lib, "RDMA WRITE to rank %d size length %ld remote address %p rkey %d lkey %d src %p", + target_rank, length, remote_addr, rkey, lkey, src); + + if (0 != (rc = ibv_post_send(comm->mcast.rc_qp[target_rank], &swr, &bad_wr))) { + tl_error(comm->lib, "RDMA Write failed rc %d rank %d remote addresss %p rkey %d", + rc, target_rank, remote_addr, rkey); + return UCC_ERR_NO_MESSAGE; + } + + return UCC_OK; +} + +ucc_status_t ucc_tl_mlx5_one_sided_p2p_get(void* src, void* remote_addr, size_t length, + uint32_t lkey, uint32_t rkey, ucc_rank_t target_rank, + uint64_t wr_id, ucc_tl_mlx5_mcast_coll_comm_t *comm) +{ + struct ibv_send_wr swr = {0}; + struct ibv_sge ssg = {0}; + struct ibv_send_wr *bad_wr; + int rc; + + if (UINT32_MAX < length) { + tl_error(comm->lib, "msg too large for p2p get"); + return UCC_ERR_NOT_SUPPORTED; + } + + ssg.addr = (uint64_t)src; + ssg.length = (uint32_t)length; + ssg.lkey = lkey; + swr.sg_list = &ssg; + swr.num_sge = 1; + + swr.opcode = IBV_WR_RDMA_READ; + swr.wr_id = wr_id; + swr.send_flags = IBV_SEND_SIGNALED; + swr.wr.rdma.remote_addr = (uint64_t)remote_addr; + swr.wr.rdma.rkey = rkey; + swr.next = NULL; + + tl_trace(comm->lib, "RDMA READ to rank %d size length %ld remote address %p rkey %d lkey %d src %p", + target_rank, length, remote_addr, rkey, lkey, src); + + if (0 != (rc = ibv_post_send(comm->mcast.rc_qp[target_rank], &swr, &bad_wr))) { + tl_error(comm->lib, "RDMA Read failed rc %d rank %d remote addresss %p rkey %d", + rc, target_rank, remote_addr, rkey); + return UCC_ERR_NO_MESSAGE; + } + + return UCC_OK; +} diff --git a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h index e82f7546a7..48b6bad4c5 100644 --- a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h +++ b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h @@ -16,3 +16,12 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void* dst, size_t size, ucc_rank_t rank, void *context, ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj); + +ucc_status_t ucc_tl_mlx5_one_sided_p2p_put(void* src, void* remote_addr, size_t length, + uint32_t lkey, uint32_t rkey, ucc_rank_t target_rank, + uint64_t wr_id, ucc_tl_mlx5_mcast_coll_comm_t *comm); + +ucc_status_t ucc_tl_mlx5_one_sided_p2p_get(void* src, void* remote_addr, size_t length, + uint32_t lkey, uint32_t rkey, ucc_rank_t target_rank, + uint64_t wr_id, ucc_tl_mlx5_mcast_coll_comm_t *comm); + diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h index 711053f1b2..1208226bda 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h @@ -33,6 +33,14 @@ #define DROP_THRESHOLD 1000 #define MAX_COMM_POW2 32 +/* Allgather RDMA-based reliability designs */ +#define ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE 1024 +#define ONE_SIDED_NO_RELIABILITY 0 +#define ONE_SIDED_SYNCHRONOUS_PROTO 1 +#define ONE_SIDED_ASYNCHRONOUS_PROTO 2 +#define ONE_SIDED_SLOTS_COUNT 2 /* number of memory slots during async design */ +#define ONE_SIDED_SLOTS_INFO_SIZE sizeof(uint32_t) /* size of metadata prepended to each slots in bytes */ + enum { MCAST_PROTO_EAGER, /* Internal staging buffers */ MCAST_PROTO_ZCOPY @@ -136,6 +144,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_context { int ib_port; int pkey_index; int mtu; + uint16_t port_lid; struct rdma_cm_id *id; struct rdma_event_channel *channel; ucc_mpool_t compl_objects_mp; @@ -174,6 +183,10 @@ struct mcast_ctx { struct ibv_ah *ah; struct ibv_send_wr swr; struct ibv_sge ssg; + // RC connection info for supporing one-sided based relibality + struct ibv_qp **rc_qp; + uint16_t *rc_lid; + union ibv_gid *rc_gid; }; struct packet { @@ -183,65 +196,109 @@ struct packet { int comm_id; }; +typedef struct ucc_tl_mlx5_mcast_slot_mem_info { + uint64_t remote_addr; + uint32_t rkey; +} ucc_tl_mlx5_mcast_slot_mem_info_t; + +typedef struct ucc_tl_mlx5_one_sided_reliable_team_info { + ucc_tl_mlx5_mcast_slot_mem_info_t slot_mem; + uint16_t port_lid; + uint32_t rc_qp_num[ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE]; +} ucc_tl_mlx5_one_sided_reliable_team_info_t; + +typedef struct ucc_tl_mlx5_mcast_one_sided_reliability_comm { + /* all the info required for establishing a reliable connection as + * well as temp slots memkeys that all processes in the team need + * to be aware of*/ + ucc_tl_mlx5_one_sided_reliable_team_info_t *info; + /* holds all the remote-addr/rkey of sendbuf from processes in the team + * used in sync design. it needs to be set during each mcast-allgather call + * after sendbuf registration */ + ucc_tl_mlx5_mcast_slot_mem_info_t *sendbuf_memkey_list; + /* counter for each target recv packet */ + uint32_t *recvd_pkts_tracker; + /* holds the remote targets' collective call counter. it is used to check + * if remote temp slot is ready for RDMA READ in async design */ + uint32_t *remote_slot_info; + struct ibv_mr *remote_slot_info_mr; + int reliability_scheme_msg_threshold; + /* mem address and mem keys of the temp slots in async design */ + char *slots_buffer; + struct ibv_mr *slots_mr; + /* size of a temp slot in async design */ + int slot_size; + /* coll req that is used during the oob service calls */ + ucc_service_coll_req_t *reliability_req; +} ucc_tl_mlx5_mcast_one_sided_reliability_comm_t; + +typedef struct ucc_tl_mlx5_mcast_service_coll { + ucc_status_t (*bcast_post) (void*, void*, size_t, ucc_rank_t, ucc_service_coll_req_t**); + ucc_status_t (*allgather_post) (void*, void*, void*, size_t, ucc_service_coll_req_t**); + ucc_status_t (*barrier_post) (void*, ucc_service_coll_req_t**); + ucc_status_t (*coll_test) (ucc_service_coll_req_t*); +} ucc_tl_mlx5_mcast_service_coll_t; + typedef struct ucc_tl_mlx5_mcast_coll_comm { - struct pp_packet dummy_packet; - ucc_tl_mlx5_mcast_coll_context_t *ctx; - ucc_tl_mlx5_mcast_coll_comm_init_spec_t params; - ucc_tl_mlx5_mcast_p2p_interface_t p2p; - int tx; - struct ibv_cq *scq; - struct ibv_cq *rcq; - ucc_rank_t rank; - ucc_rank_t commsize; - char *grh_buf; - struct ibv_mr *grh_mr; - uint16_t mcast_lid; - union ibv_gid mgid; - unsigned max_inline; - size_t max_eager; - int max_per_packet; - int pending_send; - int pending_recv; - struct ibv_mr *pp_mr; - char *pp_buf; - struct pp_packet *pp; - uint32_t psn; - uint32_t last_psn; - uint32_t racks_n; - uint32_t sacks_n; - uint32_t last_acked; - uint32_t naks_n; - uint32_t child_n; - uint32_t parent_n; - int buf_n; - struct packet p2p_pkt[MAX_COMM_POW2]; - struct packet p2p_spkt[MAX_COMM_POW2]; - ucc_list_link_t bpool; - ucc_list_link_t pending_q; - struct mcast_ctx mcast; - int reliable_in_progress; - int recv_drop_packet_in_progress; - struct ibv_recv_wr *call_rwr; - struct ibv_sge *call_rsgs; - uint64_t timer; - int stalled; - int comm_id; - void *p2p_ctx; - ucc_base_lib_t *lib; - struct sockaddr_in6 mcast_addr; - ucc_rank_t parents[MAX_COMM_POW2]; - ucc_rank_t children[MAX_COMM_POW2]; - int nack_requests; - int nacks_counter; - int n_prep_reliable; - int n_mcast_reliable; - int wsize; - ucc_tl_mlx5_mcast_join_info_t *group_setup_info; - ucc_service_coll_req_t *group_setup_info_req; - ucc_status_t (*bcast_post) (void*, void*, size_t, ucc_rank_t, ucc_service_coll_req_t**); - ucc_status_t (*bcast_test) (ucc_service_coll_req_t*); - struct rdma_cm_event *event; - struct pp_packet *r_window[1]; // do not add any new variable after here + struct pp_packet dummy_packet; + ucc_tl_mlx5_mcast_coll_context_t *ctx; + ucc_tl_mlx5_mcast_coll_comm_init_spec_t params; + ucc_tl_mlx5_mcast_p2p_interface_t p2p; + int tx; + struct ibv_cq *scq; + struct ibv_cq *rcq; + struct ibv_srq *srq; + ucc_rank_t rank; + ucc_rank_t commsize; + char *grh_buf; + struct ibv_mr *grh_mr; + uint16_t mcast_lid; + union ibv_gid mgid; + unsigned max_inline; + size_t max_eager; + int max_per_packet; + int pending_send; + int pending_recv; + struct ibv_mr *pp_mr; + char *pp_buf; + struct pp_packet *pp; + uint32_t psn; + uint32_t last_psn; + uint32_t racks_n; + uint32_t sacks_n; + uint32_t last_acked; + uint32_t naks_n; + uint32_t child_n; + uint32_t parent_n; + int buf_n; + struct packet p2p_pkt[MAX_COMM_POW2]; + struct packet p2p_spkt[MAX_COMM_POW2]; + ucc_list_link_t bpool; + ucc_list_link_t pending_q; + struct mcast_ctx mcast; + int reliable_in_progress; + int recv_drop_packet_in_progress; + struct ibv_recv_wr *call_rwr; + struct ibv_sge *call_rsgs; + uint64_t timer; + int stalled; + int comm_id; + void *p2p_ctx; + ucc_base_lib_t *lib; + struct sockaddr_in6 mcast_addr; + ucc_rank_t parents[MAX_COMM_POW2]; + ucc_rank_t children[MAX_COMM_POW2]; + int nack_requests; + int nacks_counter; + int n_prep_reliable; + int n_mcast_reliable; + int wsize; + ucc_tl_mlx5_mcast_join_info_t *group_setup_info; + ucc_service_coll_req_t *group_setup_info_req; + ucc_tl_mlx5_mcast_service_coll_t service_coll; + struct rdma_cm_event *event; + ucc_tl_mlx5_mcast_one_sided_reliability_comm_t one_sided; + struct pp_packet *r_window[1]; // note: do not add any new variable after here } ucc_tl_mlx5_mcast_coll_comm_t; typedef struct ucc_tl_mlx5_mcast_team { @@ -313,6 +370,7 @@ typedef struct ucc_tl_mlx5_mcast_oob_p2p_context { ucc_rank_t my_team_rank; ucc_subset_t subset; ucc_base_lib_t *lib; + int tmp_buf; } ucc_tl_mlx5_mcast_oob_p2p_context_t; static inline struct pp_packet* ucc_tl_mlx5_mcast_buf_get_free(ucc_tl_mlx5_mcast_coll_comm_t* comm) diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c index 0a8e52f191..0756ac142d 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c @@ -216,7 +216,8 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont } } - ctx->mtu = active_mtu; + ctx->mtu = active_mtu; + ctx->port_lid = port_attr.lid; tl_debug(ctx->lib, "port active MTU is %d and port max MTU is %d", active_mtu, max_mtu); diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c index b8e13c08ad..f57daeab5e 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c @@ -393,6 +393,140 @@ ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, return UCC_OK; } +ucc_status_t ucc_tl_mlx5_mcast_create_rc_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, + ucc_tl_mlx5_mcast_coll_comm_t *comm) +{ + int i = 0, j = 0; + struct ibv_srq_init_attr srq_init_attr; + struct ibv_qp_init_attr qp_init_attr; + + /* create srq for this RC connection */ + memset(&srq_init_attr, 0, sizeof(srq_init_attr)); + srq_init_attr.attr.max_wr = comm->params.rx_depth; + srq_init_attr.attr.max_sge = 2; + + comm->srq = ibv_create_srq(ctx->pd, &srq_init_attr); + if (!comm->srq) { + tl_error(ctx->lib, "ibv_create_srq() failed"); + return UCC_ERR_NO_RESOURCE; + } + + comm->mcast.rc_qp = ucc_calloc(1, comm->commsize * sizeof(struct ibv_qp *), "ibv_qp* list"); + if (!comm->mcast.rc_qp) { + tl_error(ctx->lib, "failed to allocate memory for ibv_qp*"); + goto failed; + } + + /* create RC qp */ + for (i = 0; i < comm->commsize; i++) { + memset(&qp_init_attr, 0, sizeof(qp_init_attr)); + + qp_init_attr.srq = comm->srq; + qp_init_attr.qp_type = IBV_QPT_RC; + qp_init_attr.send_cq = comm->scq; + qp_init_attr.recv_cq = comm->rcq; + qp_init_attr.sq_sig_all = 0; + qp_init_attr.cap.max_send_wr = comm->params.sx_depth; + qp_init_attr.cap.max_recv_wr = 0; // has srq + qp_init_attr.cap.max_inline_data = 0; + qp_init_attr.cap.max_send_sge = comm->params.sx_sge; + qp_init_attr.cap.max_recv_sge = comm->params.rx_sge; + + comm->mcast.rc_qp[i] = ibv_create_qp(ctx->pd, &qp_init_attr); + if (!comm->mcast.rc_qp[i]) { + tl_error(ctx->lib, "Failed to create mcast RC qp index %d, errno %d", i, errno); + goto failed; + } + } + + return UCC_OK; + +failed: + for (j=0; jmcast.rc_qp[j])) { + tl_error(comm->lib, "ibv_destroy_qp failed"); + return UCC_ERR_NO_RESOURCE; + } + } + + if (ibv_destroy_srq(comm->srq)) { + tl_error(comm->lib, "ibv_destroy_srq failed"); + return UCC_ERR_NO_RESOURCE; + } + + ucc_free(comm->mcast.rc_qp); + + return UCC_ERR_NO_RESOURCE; +} + +ucc_status_t ucc_tl_mlx5_mcast_modify_rc_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, + ucc_tl_mlx5_mcast_coll_comm_t *comm) +{ + ucc_rank_t my_rank = comm->rank; + struct ibv_qp_attr attr; + int i; + + for (i = 0; i < comm->commsize; i++) { + memset(&attr, 0, sizeof(attr)); + + attr.qp_state = IBV_QPS_INIT; + attr.pkey_index = 0; + attr.port_num = ctx->ib_port; + attr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | + IBV_ACCESS_REMOTE_READ | + IBV_ACCESS_REMOTE_ATOMIC; + + if (ibv_modify_qp(comm->mcast.rc_qp[i], &attr, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | + IBV_QP_ACCESS_FLAGS)) { + tl_error(ctx->lib, "Failed to move rc qp to INIT, errno %d", errno); + return UCC_ERR_NO_RESOURCE; + } + + memset(&attr, 0, sizeof(attr)); + + attr.qp_state = IBV_QPS_RTR; + attr.path_mtu = IBV_MTU_4096; + attr.dest_qp_num = comm->one_sided.info[i].rc_qp_num[my_rank]; + attr.rq_psn = DEF_PSN; + attr.max_dest_rd_atomic = 16; + attr.min_rnr_timer = 12; + attr.ah_attr.is_global = 0; + attr.ah_attr.dlid = comm->mcast.rc_lid[i]; + attr.ah_attr.dlid = comm->one_sided.info[i].port_lid; + attr.ah_attr.sl = DEF_SL; + attr.ah_attr.src_path_bits = 0; + attr.ah_attr.port_num = ctx->ib_port; + + tl_debug(comm->lib, "Connecting to rc qp to rank %d with lid %d qp_num %d port_num %d", + i, attr.ah_attr.dlid, attr.dest_qp_num, attr.ah_attr.port_num); + + if (ibv_modify_qp(comm->mcast.rc_qp[i], &attr, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN + | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER)) { + tl_error(ctx->lib, "Failed to modify rc QP index %d to RTR, errno %d", i, errno); + return UCC_ERR_NO_RESOURCE; + } + + memset(&attr, 0, sizeof(attr)); + + attr.qp_state = IBV_QPS_RTS; + attr.sq_psn = DEF_PSN; + attr.timeout = 14; + attr.retry_cnt = 7; + attr.rnr_retry = 7; /* infinite */ + attr.max_rd_atomic = 1; + if (ibv_modify_qp(comm->mcast.rc_qp[i], &attr, + IBV_QP_STATE | IBV_QP_SQ_PSN | IBV_QP_TIMEOUT | + IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_MAX_QP_RD_ATOMIC)) { + tl_error(ctx->lib, "Failed to modify rc QP index %i to RTS, errno %d", i, errno); + return UCC_ERR_NO_RESOURCE; + } + } + + return UCC_OK; +} + ucc_status_t ucc_tl_mlx5_fini_mcast_group(ucc_tl_mlx5_mcast_coll_context_t *ctx, ucc_tl_mlx5_mcast_coll_comm_t *comm) { diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h index 427039316d..9d66f3453e 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h @@ -363,6 +363,12 @@ ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, ucc_tl_mlx5_mcast_coll_comm_t *comm); +ucc_status_t ucc_tl_mlx5_mcast_create_rc_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, + ucc_tl_mlx5_mcast_coll_comm_t *comm); + +ucc_status_t ucc_tl_mlx5_mcast_modify_rc_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, + ucc_tl_mlx5_mcast_coll_comm_t *comm); + ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm); ucc_status_t ucc_tl_mlx5_mcast_join_mcast_post(ucc_tl_mlx5_mcast_coll_context_t *ctx, diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c new file mode 100644 index 0000000000..85d63a82d0 --- /dev/null +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c @@ -0,0 +1,228 @@ +/** + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#include "tl_mlx5_mcast_one_sided_reliability.h" + +static ucc_status_t ucc_tl_mlx5_mcast_one_sided_setup_reliability_buffers(ucc_base_team_t *team) +{ + ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); + ucc_status_t status = UCC_OK; + ucc_tl_mlx5_mcast_coll_comm_t *comm = tl_team->mcast->mcast_comm; + int one_sided_total_slots_size, i; + + /* this array keeps track of the number of recv packets from each process + * used in all the protocols */ + comm->one_sided.recvd_pkts_tracker = ucc_calloc(1, comm->commsize * sizeof(uint32_t), + "one_sided.recvd_pkts_tracker"); + if (!comm->one_sided.recvd_pkts_tracker) { + tl_error(comm->lib, "unable to malloc for one_sided.recvd_pkts_tracker"); + status = UCC_ERR_NO_MEMORY; + goto failed; + } + + comm->one_sided.sendbuf_memkey_list = ucc_calloc + (1, comm->commsize * sizeof(ucc_tl_mlx5_mcast_slot_mem_info_t), + "one_sided.sendbuf_memkey_list"); + if (!comm->one_sided.sendbuf_memkey_list) { + tl_error(comm->lib, "unable to malloc for one_sided.sendbuf_memkey_list"); + status = UCC_ERR_NO_MEMORY; + goto failed; + } + + /* below data structures are used in async design only */ + comm->one_sided.slot_size = comm->one_sided.reliability_scheme_msg_threshold + + ONE_SIDED_SLOTS_INFO_SIZE; + one_sided_total_slots_size = comm->one_sided.slot_size * + ONE_SIDED_SLOTS_COUNT * sizeof(char); + comm->one_sided.slots_buffer = (char *)ucc_calloc(1, one_sided_total_slots_size, + "one_sided.slots_buffer"); + if (!comm->one_sided.slots_buffer) { + tl_error(comm->lib, "unable to malloc for one_sided.slots_buffer"); + status = UCC_ERR_NO_MEMORY; + goto failed; + } + comm->one_sided.slots_mr = ibv_reg_mr(comm->ctx->pd, comm->one_sided.slots_buffer, + one_sided_total_slots_size, IBV_ACCESS_LOCAL_WRITE | + IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE); + if (!comm->one_sided.slots_mr) { + tl_error(comm->lib, "unable to register for one_sided.slots_mr"); + status = UCC_ERR_NO_RESOURCE; + goto failed; + } + + /* this array holds local information about the slot status that was read from remote ranks */ + comm->one_sided.remote_slot_info = ucc_calloc(1, comm->commsize * ONE_SIDED_SLOTS_INFO_SIZE, + "one_sided.remote_slot_info"); + if (!comm->one_sided.remote_slot_info) { + tl_error(comm->lib, "unable to malloc for one_sided.remote_slot_info"); + status = UCC_ERR_NO_MEMORY; + goto failed; + } + comm->one_sided.remote_slot_info_mr = ibv_reg_mr(comm->ctx->pd, comm->one_sided.remote_slot_info, + comm->commsize * ONE_SIDED_SLOTS_INFO_SIZE, + IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_LOCAL_WRITE | + IBV_ACCESS_REMOTE_READ); + if (!comm->one_sided.remote_slot_info_mr) { + tl_error(comm->lib, "unable to register for one_sided.remote_slot_info_mr"); + status = UCC_ERR_NO_RESOURCE; + goto failed; + } + + comm->one_sided.info = ucc_calloc(1, sizeof(ucc_tl_mlx5_one_sided_reliable_team_info_t) * + comm->commsize, "one_sided.info"); + if (!comm->one_sided.info) { + tl_error(comm->lib, "unable to allocate mem for one_sided.info"); + status = UCC_ERR_NO_MEMORY; + goto failed; + } + + status = ucc_tl_mlx5_mcast_create_rc_qps(comm->ctx, comm); + if (UCC_OK != status) { + tl_error(comm->lib, "RC qp create failed"); + goto failed; + } + + /* below holds the remote addr/rkey to local slot field of all the + * processes used in async protocol */ + comm->one_sided.info[comm->rank].slot_mem.rkey = comm->one_sided.slots_mr->rkey; + comm->one_sided.info[comm->rank].slot_mem.remote_addr = (uint64_t)comm->one_sided.slots_buffer; + comm->one_sided.info[comm->rank].port_lid = comm->ctx->port_lid; + for (i = 0; i < comm->commsize; i++) { + comm->one_sided.info[comm->rank].rc_qp_num[i] = comm->mcast.rc_qp[i]->qp_num; + } + + tl_debug(comm->lib, "created the allgather reliability structures"); + + return UCC_OK; + +failed: + return status; +} + +static inline ucc_status_t ucc_tl_mlx5_mcast_one_sided_cleanup(ucc_tl_mlx5_mcast_coll_comm_t *comm) +{ + int j; + + if (comm->mcast.rc_qp != NULL) { + for (j=0; jcommsize; j++) { + if (comm->mcast.rc_qp[j] != NULL && ibv_destroy_qp(comm->mcast.rc_qp[j])) { + tl_error(comm->lib, "ibv_destroy_qp failed"); + return UCC_ERR_NO_RESOURCE; + } + comm->mcast.rc_qp[j] = NULL; + } + + ucc_free(comm->mcast.rc_qp); + comm->mcast.rc_qp = NULL; + } + + if (comm->srq != NULL && ibv_destroy_srq(comm->srq)) { + tl_error(comm->lib, "ibv_destroy_srq failed"); + return UCC_ERR_NO_RESOURCE; + } + comm->srq = NULL; + + if (comm->one_sided.slots_mr) { + ibv_dereg_mr(comm->one_sided.slots_mr); + comm->one_sided.slots_mr = 0; + } + + if (comm->one_sided.remote_slot_info_mr) { + ibv_dereg_mr(comm->one_sided.remote_slot_info_mr); + comm->one_sided.remote_slot_info_mr = 0; + } + + if (comm->one_sided.slots_buffer) { + ucc_free(comm->one_sided.slots_buffer); + comm->one_sided.slots_buffer = NULL; + } + + if (comm->one_sided.recvd_pkts_tracker) { + ucc_free(comm->one_sided.recvd_pkts_tracker); + comm->one_sided.recvd_pkts_tracker = NULL; + } + + if (comm->one_sided.sendbuf_memkey_list) { + ucc_free(comm->one_sided.sendbuf_memkey_list); + comm->one_sided.sendbuf_memkey_list = NULL; + } + + if (comm->one_sided.remote_slot_info) { + ucc_free(comm->one_sided.remote_slot_info); + comm->one_sided.remote_slot_info = NULL; + } + + if (comm->one_sided.info) { + ucc_free(comm->one_sided.info); + comm->one_sided.info = NULL; + } + + return UCC_OK; +} + +ucc_status_t ucc_tl_mlx5_mcast_one_sided_reliability_init(ucc_base_team_t *team) +{ + ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); + ucc_tl_mlx5_mcast_coll_comm_t *comm = tl_team->mcast->mcast_comm; + ucc_status_t status = UCC_OK; + + status = ucc_tl_mlx5_mcast_one_sided_setup_reliability_buffers(team); + if (status != UCC_OK) { + tl_error(comm->lib, "setup reliablity resources failed"); + goto failed; + } + + /* TODO double check if ucc inplace allgather is working properly */ + status = comm->service_coll.allgather_post(comm->p2p_ctx, NULL /*inplace*/, comm->one_sided.info, + sizeof(ucc_tl_mlx5_one_sided_reliable_team_info_t), + &comm->one_sided.reliability_req); + if (UCC_OK != status) { + tl_error(comm->lib, "oob allgather failed during one-sided reliability init"); + goto failed; + } + + return status; + +failed: + if (UCC_OK != ucc_tl_mlx5_mcast_one_sided_cleanup(comm)) { + tl_error(comm->lib, "mcast one-sided reliablity resource cleanup failed"); + } + + return status; +} + +ucc_status_t ucc_tl_mlx5_mcast_one_sided_reliability_test(ucc_base_team_t *team) +{ + ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); + ucc_status_t status = UCC_OK; + ucc_tl_mlx5_mcast_coll_comm_t *comm = tl_team->mcast->mcast_comm; + + /* check if the one sided config info is exchanged */ + status = comm->service_coll.coll_test(comm->one_sided.reliability_req); + if (UCC_OK != status) { + /* allgather is not completed yet */ + if (status < 0) { + tl_error(comm->lib, "one sided config info exchange failed"); + goto failed; + } + return status; + } + + /* we have all the info to make the reliable connections */ + status = ucc_tl_mlx5_mcast_modify_rc_qps(comm->ctx, comm); + if (UCC_OK != status) { + tl_error(comm->lib, "RC qp modify failed"); + goto failed; + } + +failed: + if (UCC_OK != ucc_tl_mlx5_mcast_one_sided_cleanup(comm)) { + tl_error(comm->lib, "mcast one-sided reliablity resource cleanup failed"); + } + + return status; +} + diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.h new file mode 100644 index 0000000000..c8f21f7dca --- /dev/null +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.h @@ -0,0 +1,19 @@ +/** + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + + +#include "tl_mlx5.h" +#include "tl_mlx5_mcast_coll.h" +#include "coll_score/ucc_coll_score.h" +#include "tl_mlx5_mcast_helper.h" +#include "p2p/ucc_tl_mlx5_mcast_p2p.h" +#include "mcast/tl_mlx5_mcast_helper.h" +#include "mcast/tl_mlx5_mcast_service_coll.h" + + +ucc_status_t ucc_tl_mlx5_mcast_one_sided_reliability_init(ucc_base_team_t *team); + +ucc_status_t ucc_tl_mlx5_mcast_one_sided_reliability_test(ucc_base_team_t *team); diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_service_coll.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_service_coll.c new file mode 100644 index 0000000000..ea6c5dcbbd --- /dev/null +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_service_coll.c @@ -0,0 +1,83 @@ +/** + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#include "mcast/tl_mlx5_mcast_service_coll.h" + +ucc_status_t ucc_tl_mlx5_mcast_service_bcast_post(void *arg, void *buf, size_t size, ucc_rank_t root, + ucc_service_coll_req_t **bcast_req) +{ + ucc_tl_mlx5_mcast_oob_p2p_context_t *ctx = (ucc_tl_mlx5_mcast_oob_p2p_context_t *)arg; + ucc_status_t status = UCC_OK; + ucc_team_t *team = ctx->base_team; + ucc_subset_t subset = ctx->subset; + ucc_service_coll_req_t *req = NULL; + + status = ucc_service_bcast(team, buf, size, root, subset, &req); + if (ucc_unlikely(UCC_OK != status)) { + tl_error(ctx->base_ctx->lib, "tl service mcast bcast failed"); + return status; + } + + *bcast_req = req; + + return status; +} + +ucc_status_t ucc_tl_mlx5_mcast_service_allgather_post(void *arg, void *sbuf, void *rbuf, size_t size, + ucc_service_coll_req_t **ag_req) +{ + ucc_tl_mlx5_mcast_oob_p2p_context_t *ctx = (ucc_tl_mlx5_mcast_oob_p2p_context_t *)arg; + ucc_status_t status = UCC_OK; + ucc_team_t *team = ctx->base_team; + ucc_subset_t subset = ctx->subset; + ucc_service_coll_req_t *req = NULL; + + status = ucc_service_allgather(team, sbuf, rbuf, size, subset, &req); + if (ucc_unlikely(UCC_OK != status)) { + tl_error(ctx->base_ctx->lib, "tl service mcast allgather failed"); + return status; + } + + *ag_req = req; + + return status; +} + +ucc_status_t ucc_tl_mlx5_mcast_service_barrier_post(void *arg, ucc_service_coll_req_t **barrier_req) +{ + ucc_tl_mlx5_mcast_oob_p2p_context_t *ctx = (ucc_tl_mlx5_mcast_oob_p2p_context_t *)arg; + ucc_status_t status = UCC_OK; + ucc_team_t *team = ctx->base_team; + ucc_subset_t subset = ctx->subset; + ucc_service_coll_req_t *req = NULL; + + status = ucc_service_allreduce(team, &ctx->tmp_buf, &ctx->tmp_buf, UCC_DT_INT8, 1, + UCC_OP_SUM, subset, &req); + if (ucc_unlikely(UCC_OK != status)) { + tl_error(ctx->base_ctx->lib, "tl service mcast barrier failed"); + return status; + } + + *barrier_req = req; + + return status; +} + +ucc_status_t ucc_tl_mlx5_mcast_service_coll_test(ucc_service_coll_req_t *req) +{ + ucc_status_t status = UCC_OK; + + status = ucc_service_coll_test(req); + + if (UCC_INPROGRESS != status) { + if (status < 0) { + ucc_error("oob service coll progress failed"); + } + ucc_service_coll_finalize(req); + } + + return status; +} diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_service_coll.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_service_coll.h new file mode 100644 index 0000000000..c42132ace2 --- /dev/null +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_service_coll.h @@ -0,0 +1,18 @@ +/** + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#include "tl_mlx5_mcast_coll.h" + +ucc_status_t ucc_tl_mlx5_mcast_service_bcast_post(void *arg, void *buf, size_t size, ucc_rank_t root, + ucc_service_coll_req_t **bcast_req); + +ucc_status_t ucc_tl_mlx5_mcast_service_allgather_post(void *arg, void *sbuf, void *rbuf, size_t size, + ucc_service_coll_req_t **ag_req); + +ucc_status_t ucc_tl_mlx5_mcast_service_barrier_post(void *arg, ucc_service_coll_req_t **barrier_req); + +ucc_status_t ucc_tl_mlx5_mcast_service_coll_test(ucc_service_coll_req_t *req); + diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c index 61f6865669..402ff84472 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c @@ -11,40 +11,8 @@ #include "tl_mlx5_mcast_helper.h" #include "p2p/ucc_tl_mlx5_mcast_p2p.h" #include "mcast/tl_mlx5_mcast_helper.h" +#include "mcast/tl_mlx5_mcast_service_coll.h" -static ucc_status_t ucc_tl_mlx5_mcast_service_bcast_post(void *arg, void *buf, size_t size, ucc_rank_t root, - ucc_service_coll_req_t **bcast_req) -{ - ucc_tl_mlx5_mcast_oob_p2p_context_t *ctx = (ucc_tl_mlx5_mcast_oob_p2p_context_t *)arg; - ucc_status_t status = UCC_OK; - ucc_team_t *team = ctx->base_team; - ucc_subset_t subset = ctx->subset; - ucc_service_coll_req_t *req = NULL; - - status = ucc_service_bcast(team, buf, size, root, subset, &req); - if (ucc_unlikely(UCC_OK != status)) { - tl_error(ctx->base_ctx->lib, "tl service mcast bcast failed"); - return status; - } - - *bcast_req = req; - - return status; -} - -static ucc_status_t ucc_tl_mlx5_mcast_service_bcast_test(ucc_service_coll_req_t *req) -{ - ucc_status_t status = UCC_OK; - - status = ucc_service_coll_test(req); - - if (UCC_INPROGRESS != status) { - ucc_service_coll_finalize(req); - } - - return status; -} - ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, ucc_tl_mlx5_mcast_team_t **mcast_team, ucc_tl_mlx5_mcast_context_t *ctx, @@ -113,8 +81,10 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, ucc_list_head_init(&comm->bpool); ucc_list_head_init(&comm->pending_q); - comm->bcast_post = ucc_tl_mlx5_mcast_service_bcast_post; - comm->bcast_test = ucc_tl_mlx5_mcast_service_bcast_test; + comm->service_coll.bcast_post = ucc_tl_mlx5_mcast_service_bcast_post; + comm->service_coll.allgather_post = ucc_tl_mlx5_mcast_service_allgather_post; + comm->service_coll.barrier_post = ucc_tl_mlx5_mcast_service_barrier_post; + comm->service_coll.coll_test = ucc_tl_mlx5_mcast_service_coll_test; memcpy(&comm->params, conf_params, sizeof(*conf_params)); @@ -380,8 +350,8 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) data->status = UCC_ERR_NO_RESOURCE; } - status = comm->bcast_post(comm->p2p_ctx, data, sizeof(ucc_tl_mlx5_mcast_join_info_t), - 0, &comm->group_setup_info_req); + status = comm->service_coll.bcast_post(comm->p2p_ctx, data, sizeof(ucc_tl_mlx5_mcast_join_info_t), + 0, &comm->group_setup_info_req); if (UCC_OK != status) { tl_error(comm->lib, "unable to post bcast for group setup info"); ucc_free(comm->group_setup_info); @@ -403,7 +373,7 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) case TL_MLX5_TEAM_STATE_MCAST_GRP_BCAST_POST: { /* rank 0 polls bcast request and wait for its completion */ - status = comm->bcast_test(comm->group_setup_info_req); + status = comm->service_coll.coll_test(comm->group_setup_info_req); if (UCC_OK != status) { /* bcast is not completed yet */ if (status < 0) { @@ -472,8 +442,9 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) return UCC_ERR_NO_MEMORY; } - status = comm->bcast_post(comm->p2p_ctx, data, sizeof(ucc_tl_mlx5_mcast_join_info_t), - 0, &comm->group_setup_info_req); + status = comm->service_coll.bcast_post(comm->p2p_ctx, data, + sizeof(ucc_tl_mlx5_mcast_join_info_t), 0, + &comm->group_setup_info_req); if (UCC_OK != status) { tl_error(comm->lib, "unable to post bcast for group setup info"); ucc_free(data); @@ -489,7 +460,7 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) case TL_MLX5_TEAM_STATE_MCAST_GRP_BCAST_POST: { /* none rank 0 processes poll bcast request and wait for its completion */ - status = comm->bcast_test(comm->group_setup_info_req); + status = comm->service_coll.coll_test(comm->group_setup_info_req); if (UCC_OK != status) { /* bcast is not completed yet */ if (status < 0) {