Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/MLX5: one-sided mcast reliability init #980

Merged
merged 1 commit into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions src/components/tl/mlx5/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,23 @@ alltoall = \
alltoall/alltoall_inline.h \
alltoall/alltoall_coll.c

mcast = \
mcast/tl_mlx5_mcast_context.c \
mcast/tl_mlx5_mcast.h \
mcast/tl_mlx5_mcast_coll.c \
mcast/tl_mlx5_mcast_coll.h \
mcast/tl_mlx5_mcast_rcache.h \
mcast/tl_mlx5_mcast_rcache.c \
mcast/p2p/ucc_tl_mlx5_mcast_p2p.h \
mcast/p2p/ucc_tl_mlx5_mcast_p2p.c \
mcast/tl_mlx5_mcast_progress.h \
mcast/tl_mlx5_mcast_progress.c \
mcast/tl_mlx5_mcast_helper.h \
mcast/tl_mlx5_mcast_helper.c \
mcast = \
mcast/tl_mlx5_mcast_context.c \
mcast/tl_mlx5_mcast.h \
mcast/tl_mlx5_mcast_coll.c \
mcast/tl_mlx5_mcast_coll.h \
mcast/tl_mlx5_mcast_rcache.h \
mcast/tl_mlx5_mcast_rcache.c \
mcast/p2p/ucc_tl_mlx5_mcast_p2p.h \
mcast/p2p/ucc_tl_mlx5_mcast_p2p.c \
mcast/tl_mlx5_mcast_progress.h \
mcast/tl_mlx5_mcast_progress.c \
mcast/tl_mlx5_mcast_helper.h \
mcast/tl_mlx5_mcast_helper.c \
mcast/tl_mlx5_mcast_service_coll.h \
mcast/tl_mlx5_mcast_service_coll.c \
mcast/tl_mlx5_mcast_one_sided_reliability.h \
mcast/tl_mlx5_mcast_one_sided_reliability.c \
mcast/tl_mlx5_mcast_team.c

sources = \
Expand Down
78 changes: 78 additions & 0 deletions src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,81 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void *dst, size_t size, ucc_rank_t

return status;
}

MamziB marked this conversation as resolved.
Show resolved Hide resolved
/**
 * Post a signaled RDMA WRITE from a local buffer to a remote rank.
 *
 * A single-SGE work request is built and posted on the RC QP stored in
 * comm->mcast.rc_qp[target_rank]. The function only posts the request; it
 * does not poll for the completion.
 *
 * @param src          local buffer holding the payload
 * @param remote_addr  destination address on the target
 * @param length       number of bytes to write; must fit in 32 bits because
 *                     the SGE length field is a uint32_t
 * @param lkey         local key placed in the SGE
 * @param rkey         remote key placed in the work request
 * @param target_rank  index into comm->mcast.rc_qp
 * @param wr_id        value stored in the work request's wr_id
 * @param comm         mcast communicator that owns the RC QPs
 *
 * @return UCC_OK once the request is posted, UCC_ERR_NOT_SUPPORTED when
 *         length exceeds UINT32_MAX, UCC_ERR_NO_MESSAGE when
 *         ibv_post_send() fails.
 */
ucc_status_t ucc_tl_mlx5_one_sided_p2p_put(void* src, void* remote_addr, size_t length,
                                           uint32_t lkey, uint32_t rkey, ucc_rank_t target_rank,
                                           uint64_t wr_id, ucc_tl_mlx5_mcast_coll_comm_t *comm)
{
    struct ibv_send_wr  swr    = {0};
    struct ibv_sge      ssg    = {0};
    struct ibv_send_wr *bad_wr = NULL;
    int                 rc;

    if (UINT32_MAX < length) {
        tl_error(comm->lib, "msg too large for p2p put");
        return UCC_ERR_NOT_SUPPORTED;
    }

    /* go through uintptr_t so the pointer-to-integer conversion is well
     * defined regardless of pointer width */
    ssg.addr    = (uint64_t)(uintptr_t)src;
    ssg.length  = (uint32_t)length;
    ssg.lkey    = lkey;
    swr.sg_list = &ssg;
    swr.num_sge = 1;

    swr.opcode              = IBV_WR_RDMA_WRITE;
    swr.wr_id               = wr_id;
    swr.send_flags          = IBV_SEND_SIGNALED;
    swr.wr.rdma.remote_addr = (uint64_t)(uintptr_t)remote_addr;
    swr.wr.rdma.rkey        = rkey;
    swr.next                = NULL;

    /* %zu for size_t and %u for the 32-bit keys: the specifiers must match
     * the argument types */
    tl_trace(comm->lib, "RDMA WRITE to rank %d size length %zu remote address %p rkey %u lkey %u src %p",
             target_rank, length, remote_addr, rkey, lkey, src);

    rc = ibv_post_send(comm->mcast.rc_qp[target_rank], &swr, &bad_wr);
    if (0 != rc) {
        tl_error(comm->lib, "RDMA Write failed rc %d rank %d remote address %p rkey %u",
                 rc, target_rank, remote_addr, rkey);
        return UCC_ERR_NO_MESSAGE;
    }

    return UCC_OK;
}

/**
 * Post a signaled RDMA READ from a remote rank into a local buffer.
 *
 * A single-SGE work request is built and posted on the RC QP stored in
 * comm->mcast.rc_qp[target_rank]. The function only posts the request; it
 * does not poll for the completion.
 *
 * @param src          local buffer described by the SGE (for a READ this is
 *                     where the fetched data lands)
 * @param remote_addr  address on the target to read from
 * @param length       number of bytes to read; must fit in 32 bits because
 *                     the SGE length field is a uint32_t
 * @param lkey         local key placed in the SGE
 * @param rkey         remote key placed in the work request
 * @param target_rank  index into comm->mcast.rc_qp
 * @param wr_id        value stored in the work request's wr_id
 * @param comm         mcast communicator that owns the RC QPs
 *
 * @return UCC_OK once the request is posted, UCC_ERR_NOT_SUPPORTED when
 *         length exceeds UINT32_MAX, UCC_ERR_NO_MESSAGE when
 *         ibv_post_send() fails.
 */
ucc_status_t ucc_tl_mlx5_one_sided_p2p_get(void* src, void* remote_addr, size_t length,
                                           uint32_t lkey, uint32_t rkey, ucc_rank_t target_rank,
                                           uint64_t wr_id, ucc_tl_mlx5_mcast_coll_comm_t *comm)
{
    struct ibv_send_wr  swr    = {0};
    struct ibv_sge      ssg    = {0};
    struct ibv_send_wr *bad_wr = NULL;
    int                 rc;

    if (UINT32_MAX < length) {
        tl_error(comm->lib, "msg too large for p2p get");
        return UCC_ERR_NOT_SUPPORTED;
    }

    /* go through uintptr_t so the pointer-to-integer conversion is well
     * defined regardless of pointer width */
    ssg.addr    = (uint64_t)(uintptr_t)src;
    ssg.length  = (uint32_t)length;
    ssg.lkey    = lkey;
    swr.sg_list = &ssg;
    swr.num_sge = 1;

    swr.opcode              = IBV_WR_RDMA_READ;
    swr.wr_id               = wr_id;
    swr.send_flags          = IBV_SEND_SIGNALED;
    swr.wr.rdma.remote_addr = (uint64_t)(uintptr_t)remote_addr;
    swr.wr.rdma.rkey        = rkey;
    swr.next                = NULL;

    /* %zu for size_t and %u for the 32-bit keys: the specifiers must match
     * the argument types */
    tl_trace(comm->lib, "RDMA READ to rank %d size length %zu remote address %p rkey %u lkey %u src %p",
             target_rank, length, remote_addr, rkey, lkey, src);

    rc = ibv_post_send(comm->mcast.rc_qp[target_rank], &swr, &bad_wr);
    if (0 != rc) {
        tl_error(comm->lib, "RDMA Read failed rc %d rank %d remote address %p rkey %u",
                 rc, target_rank, remote_addr, rkey);
        return UCC_ERR_NO_MESSAGE;
    }

    return UCC_OK;
}
9 changes: 9 additions & 0 deletions src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,12 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void* dst, size_t size, ucc_rank_t
rank, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj);

/* Posts a signaled RDMA WRITE of `length` bytes from `src` to `remote_addr`
 * on the RC QP comm->mcast.rc_qp[target_rank]; completion is not polled here.
 * Returns UCC_OK once posted, UCC_ERR_NOT_SUPPORTED if length > UINT32_MAX,
 * UCC_ERR_NO_MESSAGE if ibv_post_send() fails. */
ucc_status_t ucc_tl_mlx5_one_sided_p2p_put(void* src, void* remote_addr, size_t length,
uint32_t lkey, uint32_t rkey, ucc_rank_t target_rank,
uint64_t wr_id, ucc_tl_mlx5_mcast_coll_comm_t *comm);

/* Posts a signaled RDMA READ of `length` bytes between `remote_addr` and the
 * local buffer `src` on the RC QP comm->mcast.rc_qp[target_rank]; completion
 * is not polled here.
 * Returns UCC_OK once posted, UCC_ERR_NOT_SUPPORTED if length > UINT32_MAX,
 * UCC_ERR_NO_MESSAGE if ibv_post_send() fails. */
ucc_status_t ucc_tl_mlx5_one_sided_p2p_get(void* src, void* remote_addr, size_t length,
uint32_t lkey, uint32_t rkey, ucc_rank_t target_rank,
uint64_t wr_id, ucc_tl_mlx5_mcast_coll_comm_t *comm);
MamziB marked this conversation as resolved.
Show resolved Hide resolved

174 changes: 116 additions & 58 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@
#define DROP_THRESHOLD 1000
#define MAX_COMM_POW2 32

/* Allgather RDMA-based reliability designs */
/* Dimension of the per-rank rc_qp_num[] array in
 * ucc_tl_mlx5_one_sided_reliable_team_info_t; presumably also the largest
 * team size the one-sided schemes accept - TODO confirm at the use site */
#define ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE 1024
MamziB marked this conversation as resolved.
Show resolved Hide resolved
/* NOTE(review): looks like the "disabled" value of the same selector that
 * takes ONE_SIDED_SYNCHRONOUS_PROTO / ONE_SIDED_ASYNCHRONOUS_PROTO - confirm */
#define ONE_SIDED_NO_RELIABILITY 0
samnordmann marked this conversation as resolved.
Show resolved Hide resolved
/* identifiers of the two one-sided reliability designs (sync / async);
 * NOTE(review): where the selection happens is not visible here */
#define ONE_SIDED_SYNCHRONOUS_PROTO 1
#define ONE_SIDED_ASYNCHRONOUS_PROTO 2
#define ONE_SIDED_SLOTS_COUNT 2 /* number of memory slots during async design */
samnordmann marked this conversation as resolved.
Show resolved Hide resolved
#define ONE_SIDED_SLOTS_INFO_SIZE sizeof(uint32_t) /* size, in bytes, of the metadata prepended to each slot */

enum {
MCAST_PROTO_EAGER, /* Internal staging buffers */
MCAST_PROTO_ZCOPY
Expand Down Expand Up @@ -136,6 +144,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_context {
int ib_port;
int pkey_index;
int mtu;
uint16_t port_lid;
struct rdma_cm_id *id;
struct rdma_event_channel *channel;
ucc_mpool_t compl_objects_mp;
Expand Down Expand Up @@ -174,6 +183,10 @@ struct mcast_ctx {
struct ibv_ah *ah;
struct ibv_send_wr swr;
struct ibv_sge ssg;
// RC connection info for supporting one-sided based reliability
struct ibv_qp **rc_qp;
uint16_t *rc_lid;
union ibv_gid *rc_gid;
};

struct packet {
Expand All @@ -183,65 +196,109 @@ struct packet {
int comm_id;
};

/* Remote address / rkey pair describing one memory region on a peer */
typedef struct ucc_tl_mlx5_mcast_slot_mem_info {
    uint64_t remote_addr; /* address of the region */
    uint32_t rkey;        /* remote key for the region */
} ucc_tl_mlx5_mcast_slot_mem_info_t;

/* One record of one-sided reliability info: a slot memory descriptor, a port
 * LID and an array of RC QP numbers.
 * NOTE(review): presumably one record per process, exchanged across the team
 * - confirm against the allgather that fills it */
typedef struct ucc_tl_mlx5_one_sided_reliable_team_info {
    ucc_tl_mlx5_mcast_slot_mem_info_t slot_mem; /* remote addr + rkey of the slot memory */
    uint16_t                          port_lid; /* port LID */
    /* RC QP numbers, ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE entries */
    uint32_t                          rc_qp_num[ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE];
} ucc_tl_mlx5_one_sided_reliable_team_info_t;

/* State kept per communicator for the one-sided reliability designs */
typedef struct ucc_tl_mlx5_mcast_one_sided_reliability_comm {
    /* everything needed to set up the reliable connections, together with
     * the memkeys of the temp slots; all processes of the team must know it */
    ucc_tl_mlx5_one_sided_reliable_team_info_t *info;
    /* sync design: remote-addr/rkey of the sendbuf of every process in the
     * team. Has to be refreshed on each mcast-allgather call, once the
     * sendbuf has been registered */
    ucc_tl_mlx5_mcast_slot_mem_info_t          *sendbuf_memkey_list;
    /* received-packet counter, one per target */
    uint32_t                                   *recvd_pkts_tracker;
    /* async design: collective call counters of the remote targets, consulted
     * to tell whether a remote temp slot is ready for RDMA READ */
    uint32_t                                   *remote_slot_info;
    struct ibv_mr                              *remote_slot_info_mr;
    int                                         reliability_scheme_msg_threshold;
    /* async design: address of the temp slots and their memory registration */
    char                                       *slots_buffer;
    struct ibv_mr                              *slots_mr;
    /* async design: size of one temp slot */
    int                                         slot_size;
    /* request object used by the oob service calls */
    ucc_service_coll_req_t                     *reliability_req;
} ucc_tl_mlx5_mcast_one_sided_reliability_comm_t;

/* Table of service-collective entry points. Each *_post takes a
 * ucc_service_coll_req_t** as its last argument; coll_test takes a request.
 * NOTE(review): the roles of the leading void* arguments are not visible
 * here - confirm against the implementations assigned to these pointers */
typedef struct ucc_tl_mlx5_mcast_service_coll {
    ucc_status_t (*bcast_post)(void*, void*, size_t, ucc_rank_t,
                               ucc_service_coll_req_t**);
    ucc_status_t (*allgather_post)(void*, void*, void*, size_t,
                                   ucc_service_coll_req_t**);
    ucc_status_t (*barrier_post)(void*, ucc_service_coll_req_t**);
    ucc_status_t (*coll_test)(ucc_service_coll_req_t*);
} ucc_tl_mlx5_mcast_service_coll_t;

typedef struct ucc_tl_mlx5_mcast_coll_comm {
struct pp_packet dummy_packet;
ucc_tl_mlx5_mcast_coll_context_t *ctx;
ucc_tl_mlx5_mcast_coll_comm_init_spec_t params;
ucc_tl_mlx5_mcast_p2p_interface_t p2p;
int tx;
struct ibv_cq *scq;
struct ibv_cq *rcq;
ucc_rank_t rank;
ucc_rank_t commsize;
char *grh_buf;
struct ibv_mr *grh_mr;
uint16_t mcast_lid;
union ibv_gid mgid;
unsigned max_inline;
size_t max_eager;
int max_per_packet;
int pending_send;
int pending_recv;
struct ibv_mr *pp_mr;
char *pp_buf;
struct pp_packet *pp;
uint32_t psn;
uint32_t last_psn;
uint32_t racks_n;
uint32_t sacks_n;
uint32_t last_acked;
uint32_t naks_n;
uint32_t child_n;
uint32_t parent_n;
int buf_n;
struct packet p2p_pkt[MAX_COMM_POW2];
struct packet p2p_spkt[MAX_COMM_POW2];
ucc_list_link_t bpool;
ucc_list_link_t pending_q;
struct mcast_ctx mcast;
int reliable_in_progress;
int recv_drop_packet_in_progress;
struct ibv_recv_wr *call_rwr;
struct ibv_sge *call_rsgs;
uint64_t timer;
int stalled;
int comm_id;
void *p2p_ctx;
ucc_base_lib_t *lib;
struct sockaddr_in6 mcast_addr;
ucc_rank_t parents[MAX_COMM_POW2];
ucc_rank_t children[MAX_COMM_POW2];
int nack_requests;
int nacks_counter;
int n_prep_reliable;
int n_mcast_reliable;
int wsize;
ucc_tl_mlx5_mcast_join_info_t *group_setup_info;
ucc_service_coll_req_t *group_setup_info_req;
ucc_status_t (*bcast_post) (void*, void*, size_t, ucc_rank_t, ucc_service_coll_req_t**);
ucc_status_t (*bcast_test) (ucc_service_coll_req_t*);
struct rdma_cm_event *event;
struct pp_packet *r_window[1]; // do not add any new variable after here
struct pp_packet dummy_packet;
ucc_tl_mlx5_mcast_coll_context_t *ctx;
ucc_tl_mlx5_mcast_coll_comm_init_spec_t params;
ucc_tl_mlx5_mcast_p2p_interface_t p2p;
int tx;
struct ibv_cq *scq;
struct ibv_cq *rcq;
struct ibv_srq *srq;
ucc_rank_t rank;
ucc_rank_t commsize;
char *grh_buf;
struct ibv_mr *grh_mr;
uint16_t mcast_lid;
union ibv_gid mgid;
unsigned max_inline;
size_t max_eager;
int max_per_packet;
int pending_send;
int pending_recv;
struct ibv_mr *pp_mr;
char *pp_buf;
struct pp_packet *pp;
uint32_t psn;
uint32_t last_psn;
uint32_t racks_n;
uint32_t sacks_n;
uint32_t last_acked;
uint32_t naks_n;
uint32_t child_n;
uint32_t parent_n;
int buf_n;
struct packet p2p_pkt[MAX_COMM_POW2];
struct packet p2p_spkt[MAX_COMM_POW2];
ucc_list_link_t bpool;
ucc_list_link_t pending_q;
struct mcast_ctx mcast;
int reliable_in_progress;
int recv_drop_packet_in_progress;
struct ibv_recv_wr *call_rwr;
struct ibv_sge *call_rsgs;
uint64_t timer;
int stalled;
int comm_id;
void *p2p_ctx;
ucc_base_lib_t *lib;
struct sockaddr_in6 mcast_addr;
ucc_rank_t parents[MAX_COMM_POW2];
ucc_rank_t children[MAX_COMM_POW2];
int nack_requests;
int nacks_counter;
int n_prep_reliable;
int n_mcast_reliable;
int wsize;
ucc_tl_mlx5_mcast_join_info_t *group_setup_info;
ucc_service_coll_req_t *group_setup_info_req;
ucc_tl_mlx5_mcast_service_coll_t service_coll;
struct rdma_cm_event *event;
ucc_tl_mlx5_mcast_one_sided_reliability_comm_t one_sided;
struct pp_packet *r_window[1]; // note: do not add any new variable after here
} ucc_tl_mlx5_mcast_coll_comm_t;

typedef struct ucc_tl_mlx5_mcast_team {
Expand Down Expand Up @@ -313,6 +370,7 @@ typedef struct ucc_tl_mlx5_mcast_oob_p2p_context {
ucc_rank_t my_team_rank;
ucc_subset_t subset;
ucc_base_lib_t *lib;
int tmp_buf;
} ucc_tl_mlx5_mcast_oob_p2p_context_t;

static inline struct pp_packet* ucc_tl_mlx5_mcast_buf_get_free(ucc_tl_mlx5_mcast_coll_comm_t* comm)
Expand Down
3 changes: 2 additions & 1 deletion src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont
}
}

ctx->mtu = active_mtu;
ctx->mtu = active_mtu;
ctx->port_lid = port_attr.lid;

tl_debug(ctx->lib, "port active MTU is %d and port max MTU is %d",
active_mtu, max_mtu);
Expand Down
Loading
Loading