Skip to content

Commit

Permalink
TL/MLX5: one-sided mcast reliability init (openucx#980)
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB committed Jun 27, 2024
1 parent 039403c commit fd627dd
Show file tree
Hide file tree
Showing 12 changed files with 722 additions and 113 deletions.
30 changes: 17 additions & 13 deletions src/components/tl/mlx5/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,23 @@ alltoall = \
alltoall/alltoall_inline.h \
alltoall/alltoall_coll.c

# Sources and headers of the multicast (mcast) part of TL/MLX5, including the
# service-collective wrappers and the one-sided reliability implementation.
mcast =                                            \
	mcast/tl_mlx5_mcast_context.c                  \
	mcast/tl_mlx5_mcast.h                          \
	mcast/tl_mlx5_mcast_coll.c                     \
	mcast/tl_mlx5_mcast_coll.h                     \
	mcast/tl_mlx5_mcast_rcache.h                   \
	mcast/tl_mlx5_mcast_rcache.c                   \
	mcast/p2p/ucc_tl_mlx5_mcast_p2p.h              \
	mcast/p2p/ucc_tl_mlx5_mcast_p2p.c              \
	mcast/tl_mlx5_mcast_progress.h                 \
	mcast/tl_mlx5_mcast_progress.c                 \
	mcast/tl_mlx5_mcast_helper.h                   \
	mcast/tl_mlx5_mcast_helper.c                   \
	mcast/tl_mlx5_mcast_service_coll.h             \
	mcast/tl_mlx5_mcast_service_coll.c             \
	mcast/tl_mlx5_mcast_one_sided_reliability.h    \
	mcast/tl_mlx5_mcast_one_sided_reliability.c    \
	mcast/tl_mlx5_mcast_team.c

sources = \
Expand Down
78 changes: 78 additions & 0 deletions src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,81 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void *dst, size_t size, ucc_rank_t

return status;
}

/**
 * Post a signaled one-sided RDMA WRITE to @target_rank.
 *
 * The work request is posted on comm->mcast.rc_qp[target_rank]. The function
 * only posts the request; it returns without polling for its completion.
 *
 * @param src          local buffer holding the data to write (covered by @lkey)
 * @param remote_addr  destination address on the target (covered by @rkey)
 * @param length       number of bytes to write; must fit in the 32-bit SGE
 *                     length field
 * @param lkey         local memory key for @src
 * @param rkey         remote memory key for @remote_addr
 * @param target_rank  index into comm->mcast.rc_qp selecting the QP to post on
 * @param wr_id        work request id stored in the posted send WR
 * @param comm         mcast communicator providing the RC QPs and the lib
 *                     handle used for logging
 *
 * @return UCC_OK on successful post,
 *         UCC_ERR_NOT_SUPPORTED if @length exceeds UINT32_MAX,
 *         UCC_ERR_NO_MESSAGE if ibv_post_send() fails.
 */
ucc_status_t ucc_tl_mlx5_one_sided_p2p_put(void* src, void* remote_addr, size_t length,
                                           uint32_t lkey, uint32_t rkey, ucc_rank_t target_rank,
                                           uint64_t wr_id, ucc_tl_mlx5_mcast_coll_comm_t *comm)
{
    struct ibv_send_wr  swr    = {0};
    struct ibv_sge      ssg    = {0};
    struct ibv_send_wr *bad_wr = NULL;
    int                 rc;

    /* ibv_sge.length is 32 bits wide: reject anything that would be truncated */
    if (length > UINT32_MAX) {
        tl_error(comm->lib, "msg too large for p2p put");
        return UCC_ERR_NOT_SUPPORTED;
    }

    /* go through uintptr_t so the pointer -> integer conversion is well defined */
    ssg.addr   = (uint64_t)(uintptr_t)src;
    ssg.length = (uint32_t)length;
    ssg.lkey   = lkey;

    swr.sg_list             = &ssg;
    swr.num_sge             = 1;
    swr.opcode              = IBV_WR_RDMA_WRITE;
    swr.wr_id               = wr_id;
    swr.send_flags          = IBV_SEND_SIGNALED;
    swr.wr.rdma.remote_addr = (uint64_t)(uintptr_t)remote_addr;
    swr.wr.rdma.rkey        = rkey;
    swr.next                = NULL;

    /* %zu for size_t and %u for the 32-bit keys: the conversion specifiers
     * must match the argument types */
    tl_trace(comm->lib, "RDMA WRITE to rank %d size length %zu remote address %p rkey %u lkey %u src %p",
             target_rank, length, remote_addr, (unsigned)rkey, (unsigned)lkey, src);

    rc = ibv_post_send(comm->mcast.rc_qp[target_rank], &swr, &bad_wr);
    if (rc != 0) {
        tl_error(comm->lib, "RDMA Write failed rc %d rank %d remote address %p rkey %u",
                 rc, target_rank, remote_addr, (unsigned)rkey);
        return UCC_ERR_NO_MESSAGE;
    }

    return UCC_OK;
}

/**
 * Post a signaled one-sided RDMA READ from @target_rank.
 *
 * The work request is posted on comm->mcast.rc_qp[target_rank]. The function
 * only posts the request; it returns without polling for its completion.
 *
 * @param src          local buffer placed in the SGE (covered by @lkey); for an
 *                     RDMA READ the verbs API delivers the fetched data into
 *                     the local SGE buffer, so despite its name this is where
 *                     the data is received
 * @param remote_addr  address on the target to read from (covered by @rkey)
 * @param length       number of bytes to read; must fit in the 32-bit SGE
 *                     length field
 * @param lkey         local memory key for @src
 * @param rkey         remote memory key for @remote_addr
 * @param target_rank  index into comm->mcast.rc_qp selecting the QP to post on
 * @param wr_id        work request id stored in the posted send WR
 * @param comm         mcast communicator providing the RC QPs and the lib
 *                     handle used for logging
 *
 * @return UCC_OK on successful post,
 *         UCC_ERR_NOT_SUPPORTED if @length exceeds UINT32_MAX,
 *         UCC_ERR_NO_MESSAGE if ibv_post_send() fails.
 */
ucc_status_t ucc_tl_mlx5_one_sided_p2p_get(void* src, void* remote_addr, size_t length,
                                           uint32_t lkey, uint32_t rkey, ucc_rank_t target_rank,
                                           uint64_t wr_id, ucc_tl_mlx5_mcast_coll_comm_t *comm)
{
    struct ibv_send_wr  swr    = {0};
    struct ibv_sge      ssg    = {0};
    struct ibv_send_wr *bad_wr = NULL;
    int                 rc;

    /* ibv_sge.length is 32 bits wide: reject anything that would be truncated */
    if (length > UINT32_MAX) {
        tl_error(comm->lib, "msg too large for p2p get");
        return UCC_ERR_NOT_SUPPORTED;
    }

    /* go through uintptr_t so the pointer -> integer conversion is well defined */
    ssg.addr   = (uint64_t)(uintptr_t)src;
    ssg.length = (uint32_t)length;
    ssg.lkey   = lkey;

    swr.sg_list             = &ssg;
    swr.num_sge             = 1;
    swr.opcode              = IBV_WR_RDMA_READ;
    swr.wr_id               = wr_id;
    swr.send_flags          = IBV_SEND_SIGNALED;
    swr.wr.rdma.remote_addr = (uint64_t)(uintptr_t)remote_addr;
    swr.wr.rdma.rkey        = rkey;
    swr.next                = NULL;

    /* %zu for size_t and %u for the 32-bit keys: the conversion specifiers
     * must match the argument types */
    tl_trace(comm->lib, "RDMA READ to rank %d size length %zu remote address %p rkey %u lkey %u src %p",
             target_rank, length, remote_addr, (unsigned)rkey, (unsigned)lkey, src);

    rc = ibv_post_send(comm->mcast.rc_qp[target_rank], &swr, &bad_wr);
    if (rc != 0) {
        tl_error(comm->lib, "RDMA Read failed rc %d rank %d remote address %p rkey %u",
                 rc, target_rank, remote_addr, (unsigned)rkey);
        return UCC_ERR_NO_MESSAGE;
    }

    return UCC_OK;
}
9 changes: 9 additions & 0 deletions src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,12 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void* dst, size_t size, ucc_rank_t
rank, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj);

/* Posts a one-sided RDMA WRITE of length bytes from the local buffer src
 * (local key lkey) to remote_addr (remote key rkey) on the RC QP of
 * target_rank; wr_id is attached to the posted work request.
 * Returns a ucc_status_t describing whether the request could be posted. */
ucc_status_t ucc_tl_mlx5_one_sided_p2p_put(void* src, void* remote_addr, size_t length,
uint32_t lkey, uint32_t rkey, ucc_rank_t target_rank,
uint64_t wr_id, ucc_tl_mlx5_mcast_coll_comm_t *comm);

/* Posts a one-sided RDMA READ of length bytes from remote_addr (remote key
 * rkey) on the RC QP of target_rank, using src (local key lkey) as the local
 * buffer of the request; wr_id is attached to the posted work request.
 * Returns a ucc_status_t describing whether the request could be posted. */
ucc_status_t ucc_tl_mlx5_one_sided_p2p_get(void* src, void* remote_addr, size_t length,
uint32_t lkey, uint32_t rkey, ucc_rank_t target_rank,
uint64_t wr_id, ucc_tl_mlx5_mcast_coll_comm_t *comm);

174 changes: 116 additions & 58 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@
#define DROP_THRESHOLD 1000
#define MAX_COMM_POW2 32

/* Allgather RDMA-based (one-sided) reliability designs */
/* upper bound on the team size; also the length of the per-rank rc_qp_num
 * array in ucc_tl_mlx5_one_sided_reliable_team_info_t */
#define ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE 1024
/* NOTE(review): the three values below look like selectors for the
 * reliability protocol (none / synchronous / asynchronous) - confirm where
 * they are consumed */
#define ONE_SIDED_NO_RELIABILITY 0
#define ONE_SIDED_SYNCHRONOUS_PROTO 1
#define ONE_SIDED_ASYNCHRONOUS_PROTO 2
#define ONE_SIDED_SLOTS_COUNT 2 /* number of memory slots used by the async design */
#define ONE_SIDED_SLOTS_INFO_SIZE sizeof(uint32_t) /* size in bytes of the metadata prepended to each slot */

enum {
MCAST_PROTO_EAGER, /* Internal staging buffers */
MCAST_PROTO_ZCOPY
Expand Down Expand Up @@ -136,6 +144,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_context {
int ib_port;
int pkey_index;
int mtu;
uint16_t port_lid;
struct rdma_cm_id *id;
struct rdma_event_channel *channel;
ucc_mpool_t compl_objects_mp;
Expand Down Expand Up @@ -174,6 +183,10 @@ struct mcast_ctx {
struct ibv_ah *ah;
struct ibv_send_wr swr;
struct ibv_sge ssg;
// RC connection info for supporting one-sided based reliability
struct ibv_qp **rc_qp;
uint16_t *rc_lid;
union ibv_gid *rc_gid;
};

struct packet {
Expand All @@ -183,65 +196,109 @@ struct packet {
int comm_id;
};

/* Remote address / remote key pair describing one remotely accessible memory
 * region.
 * NOTE(review): presumably these values feed the remote_addr / rkey arguments
 * of the one-sided put/get helpers - confirm against the callers. */
typedef struct ucc_tl_mlx5_mcast_slot_mem_info {
/* address of the region on the remote side */
uint64_t remote_addr;
/* remote key of the region */
uint32_t rkey;
} ucc_tl_mlx5_mcast_slot_mem_info_t;

/* Per-process record carrying a slot memory descriptor, a port LID and an
 * array of RC QP numbers.
 * NOTE(review): presumably one such record per rank is exchanged across the
 * team to set up the reliable connections - confirm against the init code. */
typedef struct ucc_tl_mlx5_one_sided_reliable_team_info {
/* remote address / rkey of this process' slot memory */
ucc_tl_mlx5_mcast_slot_mem_info_t slot_mem;
/* LID of the port */
uint16_t port_lid;
/* RC QP numbers; sized for the largest supported team
 * (presumably indexed by peer rank - TODO confirm) */
uint32_t rc_qp_num[ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE];
} ucc_tl_mlx5_one_sided_reliable_team_info_t;

/* State of the one-sided (RDMA-based) reliability protocol attached to an
 * mcast communicator. */
typedef struct ucc_tl_mlx5_mcast_one_sided_reliability_comm {
/* all the info required for establishing a reliable connection, as well as
 * the memory keys of the temp slots, that every process in the team needs
 * to be aware of */
ucc_tl_mlx5_one_sided_reliable_team_info_t *info;
/* holds the remote-addr/rkey of the sendbuf of every process in the team;
 * used in the sync design. It needs to be set during each mcast-allgather
 * call, after sendbuf registration */
ucc_tl_mlx5_mcast_slot_mem_info_t *sendbuf_memkey_list;
/* counter of received packets for each target */
uint32_t *recvd_pkts_tracker;
/* holds the remote targets' collective call counter. It is used to check
 * whether a remote temp slot is ready for RDMA READ in the async design */
uint32_t *remote_slot_info;
/* memory registration associated with remote_slot_info (by name - TODO confirm) */
struct ibv_mr *remote_slot_info_mr;
/* NOTE(review): presumably a message-size threshold used to choose the
 * reliability scheme - confirm against the code that reads it */
int reliability_scheme_msg_threshold;
/* memory address and memory keys of the temp slots in the async design */
char *slots_buffer;
struct ibv_mr *slots_mr;
/* size of a temp slot in the async design */
int slot_size;
/* coll req that is used during the oob service calls */
ucc_service_coll_req_t *reliability_req;
} ucc_tl_mlx5_mcast_one_sided_reliability_comm_t;

/* Table of function pointers for the service collectives (bcast, allgather,
 * barrier) and for testing a posted request. Each *_post entry returns its
 * request through the trailing ucc_service_coll_req_t** argument.
 * NOTE(review): the meaning of the leading void* arguments (context, buffers)
 * is not visible here - see the functions assigned to these pointers. */
typedef struct ucc_tl_mlx5_mcast_service_coll {
/* post a broadcast: two opaque pointers, a size and a rank (presumably the root - TODO confirm) */
ucc_status_t (*bcast_post) (void*, void*, size_t, ucc_rank_t, ucc_service_coll_req_t**);
/* post an allgather: three opaque pointers and a size */
ucc_status_t (*allgather_post) (void*, void*, void*, size_t, ucc_service_coll_req_t**);
/* post a barrier: one opaque pointer */
ucc_status_t (*barrier_post) (void*, ucc_service_coll_req_t**);
/* test a previously posted request */
ucc_status_t (*coll_test) (ucc_service_coll_req_t*);
} ucc_tl_mlx5_mcast_service_coll_t;

typedef struct ucc_tl_mlx5_mcast_coll_comm {
struct pp_packet dummy_packet;
ucc_tl_mlx5_mcast_coll_context_t *ctx;
ucc_tl_mlx5_mcast_coll_comm_init_spec_t params;
ucc_tl_mlx5_mcast_p2p_interface_t p2p;
int tx;
struct ibv_cq *scq;
struct ibv_cq *rcq;
ucc_rank_t rank;
ucc_rank_t commsize;
char *grh_buf;
struct ibv_mr *grh_mr;
uint16_t mcast_lid;
union ibv_gid mgid;
unsigned max_inline;
size_t max_eager;
int max_per_packet;
int pending_send;
int pending_recv;
struct ibv_mr *pp_mr;
char *pp_buf;
struct pp_packet *pp;
uint32_t psn;
uint32_t last_psn;
uint32_t racks_n;
uint32_t sacks_n;
uint32_t last_acked;
uint32_t naks_n;
uint32_t child_n;
uint32_t parent_n;
int buf_n;
struct packet p2p_pkt[MAX_COMM_POW2];
struct packet p2p_spkt[MAX_COMM_POW2];
ucc_list_link_t bpool;
ucc_list_link_t pending_q;
struct mcast_ctx mcast;
int reliable_in_progress;
int recv_drop_packet_in_progress;
struct ibv_recv_wr *call_rwr;
struct ibv_sge *call_rsgs;
uint64_t timer;
int stalled;
int comm_id;
void *p2p_ctx;
ucc_base_lib_t *lib;
struct sockaddr_in6 mcast_addr;
ucc_rank_t parents[MAX_COMM_POW2];
ucc_rank_t children[MAX_COMM_POW2];
int nack_requests;
int nacks_counter;
int n_prep_reliable;
int n_mcast_reliable;
int wsize;
ucc_tl_mlx5_mcast_join_info_t *group_setup_info;
ucc_service_coll_req_t *group_setup_info_req;
ucc_status_t (*bcast_post) (void*, void*, size_t, ucc_rank_t, ucc_service_coll_req_t**);
ucc_status_t (*bcast_test) (ucc_service_coll_req_t*);
struct rdma_cm_event *event;
struct pp_packet *r_window[1]; // do not add any new variable after here
struct pp_packet dummy_packet;
ucc_tl_mlx5_mcast_coll_context_t *ctx;
ucc_tl_mlx5_mcast_coll_comm_init_spec_t params;
ucc_tl_mlx5_mcast_p2p_interface_t p2p;
int tx;
struct ibv_cq *scq;
struct ibv_cq *rcq;
struct ibv_srq *srq;
ucc_rank_t rank;
ucc_rank_t commsize;
char *grh_buf;
struct ibv_mr *grh_mr;
uint16_t mcast_lid;
union ibv_gid mgid;
unsigned max_inline;
size_t max_eager;
int max_per_packet;
int pending_send;
int pending_recv;
struct ibv_mr *pp_mr;
char *pp_buf;
struct pp_packet *pp;
uint32_t psn;
uint32_t last_psn;
uint32_t racks_n;
uint32_t sacks_n;
uint32_t last_acked;
uint32_t naks_n;
uint32_t child_n;
uint32_t parent_n;
int buf_n;
struct packet p2p_pkt[MAX_COMM_POW2];
struct packet p2p_spkt[MAX_COMM_POW2];
ucc_list_link_t bpool;
ucc_list_link_t pending_q;
struct mcast_ctx mcast;
int reliable_in_progress;
int recv_drop_packet_in_progress;
struct ibv_recv_wr *call_rwr;
struct ibv_sge *call_rsgs;
uint64_t timer;
int stalled;
int comm_id;
void *p2p_ctx;
ucc_base_lib_t *lib;
struct sockaddr_in6 mcast_addr;
ucc_rank_t parents[MAX_COMM_POW2];
ucc_rank_t children[MAX_COMM_POW2];
int nack_requests;
int nacks_counter;
int n_prep_reliable;
int n_mcast_reliable;
int wsize;
ucc_tl_mlx5_mcast_join_info_t *group_setup_info;
ucc_service_coll_req_t *group_setup_info_req;
ucc_tl_mlx5_mcast_service_coll_t service_coll;
struct rdma_cm_event *event;
ucc_tl_mlx5_mcast_one_sided_reliability_comm_t one_sided;
struct pp_packet *r_window[1]; // note: do not add any new variable after here
} ucc_tl_mlx5_mcast_coll_comm_t;

typedef struct ucc_tl_mlx5_mcast_team {
Expand Down Expand Up @@ -313,6 +370,7 @@ typedef struct ucc_tl_mlx5_mcast_oob_p2p_context {
ucc_rank_t my_team_rank;
ucc_subset_t subset;
ucc_base_lib_t *lib;
int tmp_buf;
} ucc_tl_mlx5_mcast_oob_p2p_context_t;

static inline struct pp_packet* ucc_tl_mlx5_mcast_buf_get_free(ucc_tl_mlx5_mcast_coll_comm_t* comm)
Expand Down
3 changes: 2 additions & 1 deletion src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont
}
}

ctx->mtu = active_mtu;
ctx->mtu = active_mtu;
ctx->port_lid = port_attr.lid;

tl_debug(ctx->lib, "port active MTU is %d and port max MTU is %d",
active_mtu, max_mtu);
Expand Down
Loading

0 comments on commit fd627dd

Please sign in to comment.