diff --git a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c
index ea57bfa89c..f2019b4ae7 100644
--- a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c
+++ b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c
@@ -27,6 +27,7 @@ void ucc_tl_mlx5_mcast_completion_cb(void* context, ucc_status_t status) //NOLIN
 static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t len,
                                                              ucc_rank_t my_team_rank, ucc_rank_t dest,
+                                                             ucc_memory_type_t mem_type,
                                                              ucc_team_h team, ucc_coll_callback_t *callback,
                                                              ucc_coll_req_h *p2p_req, int is_send,
                                                              ucc_base_lib_t *lib)
@@ -41,7 +42,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t
     args.src.info.buffer   = buf;
     args.src.info.count    = len;
     args.src.info.datatype = UCC_DT_INT8;
-    args.src.info.mem_type = UCC_MEMORY_TYPE_HOST;
+    args.src.info.mem_type = mem_type;
     args.root              = is_send ? my_team_rank : dest;
     args.cb.cb             = callback->cb;
     args.cb.data           = callback->data;
@@ -69,25 +70,27 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t
 }
 
 static inline ucc_status_t do_send_nb(void *sbuf, size_t len, ucc_rank_t
-                                      my_team_rank, ucc_rank_t dest, ucc_team_h team,
+                                      my_team_rank, ucc_rank_t dest,
+                                      ucc_memory_type_t mem_type, ucc_team_h team,
                                       ucc_coll_callback_t *callback, ucc_coll_req_h *req,
                                       ucc_base_lib_t *lib)
 {
-    return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(sbuf, len, my_team_rank, dest,
+    return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(sbuf, len, my_team_rank, dest, mem_type,
                                              team, callback, req, 1, lib);
 }
 
 static inline ucc_status_t do_recv_nb(void *rbuf, size_t len, ucc_rank_t
-                                      my_team_rank, ucc_rank_t dest, ucc_team_h team,
+                                      my_team_rank, ucc_rank_t dest,
+                                      ucc_memory_type_t mem_type, ucc_team_h team,
                                       ucc_coll_callback_t *callback, ucc_coll_req_h *req,
                                       ucc_base_lib_t *lib)
 {
-    return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(rbuf, len, my_team_rank, dest,
+    return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(rbuf, len, my_team_rank, dest, mem_type,
                                              team, callback, req, 0, lib);
 }
 
 ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
-                                           rank, void *context,
+                                           rank, ucc_memory_type_t mem_type, void *context,
                                            ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj)
 {
@@ -103,7 +106,7 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
     callback.data = obj;
 
     tl_trace(oob_p2p_ctx->lib, "P2P: SEND to %d Msg Size %ld", rank, size);
-    status = do_send_nb(src, size, my_team_rank, rank, team, &callback, &req, oob_p2p_ctx->lib);
+    status = do_send_nb(src, size, my_team_rank, rank, mem_type, team, &callback, &req, oob_p2p_ctx->lib);
 
     if (status < 0) {
         tl_error(oob_p2p_ctx->lib, "nonblocking p2p send failed");
@@ -114,7 +117,7 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
 }
 
 ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void *dst, size_t size, ucc_rank_t
-                                           rank, void *context,
+                                           rank, ucc_memory_type_t mem_type, void *context,
                                            ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj)
 {
@@ -130,7 +133,7 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void *dst, size_t size, ucc_rank_t
     callback.data = obj;
 
     tl_trace(oob_p2p_ctx->lib, "P2P: RECV to %d Msg Size %ld", rank, size);
-    status = do_recv_nb(dst, size, my_team_rank, rank, team, &callback, &req, oob_p2p_ctx->lib);
+    status = do_recv_nb(dst, size, my_team_rank, rank, mem_type, team, &callback, &req, oob_p2p_ctx->lib);
 
     if (status < 0) {
         tl_error(oob_p2p_ctx->lib, "nonblocking p2p recv failed");
diff --git a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h
index 48b6bad4c5..32427cb0b9 100644
--- a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h
+++ b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h
@@ -8,12 +8,12 @@
 #include "components/tl/mlx5/mcast/tl_mlx5_mcast.h"
 
 ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
-                                           rank, void *context,
+                                           rank, ucc_memory_type_t mem_type, void *context,
                                            ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj);
 
 ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void* dst, size_t size, ucc_rank_t
-                                           rank, void *context,
+                                           rank, ucc_memory_type_t mem_type, void *context,
                                            ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj);
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
index 4e1caebf2a..9d9cee6899 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
@@ -31,7 +31,7 @@
 #define DEF_SL            0
 #define DEF_SRC_PATH_BITS 0
 #define GRH_LENGTH        40
-#define DROP_THRESHOLD    1000
+#define DROP_THRESHOLD    10000
 #define MAX_COMM_POW2     32
 
 /* Allgather RDMA-based reliability designs */
@@ -75,12 +75,12 @@ typedef struct ucc_tl_mlx5_mcast_p2p_completion_obj {
 typedef int (*ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t)(void *wait_arg);
 
 typedef ucc_status_t (*ucc_tl_mlx5_mcast_p2p_send_nb_fn_t)(void* src, size_t size,
-                                                           ucc_rank_t rank, void *context,
+                                                           ucc_rank_t rank, ucc_memory_type_t mem_type, void *context,
                                                            ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj);
 
 typedef ucc_status_t (*ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t)(void* src, size_t size,
-                                                           ucc_rank_t rank, void *context,
+                                                           ucc_rank_t rank, ucc_memory_type_t mem_type, void *context,
                                                            ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj);
 
 typedef struct ucc_tl_mlx5_mcast_p2p_interface {
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
index a6725698ce..8289cc4339 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
@@ -182,7 +182,9 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_prepare_bcast(void* buf, size_t siz
     req->am_root = (root == comm->rank);
     req->mr      = comm->pp_mr;
     req->rreg    = NULL;
-    req->proto   = (req->length < comm->max_eager) ? MCAST_PROTO_EAGER : MCAST_PROTO_ZCOPY;
+    /* cost of copy is too high in cuda buffers */
+    req->proto   = (req->length < comm->max_eager && !comm->cuda_mem_enabled) ?
+                    MCAST_PROTO_EAGER : MCAST_PROTO_ZCOPY;
 
     status = ucc_tl_mlx5_mcast_prepare_reliable(comm, req, req->root);
     if (ucc_unlikely(UCC_OK != status)) {
@@ -283,8 +285,8 @@ void ucc_tl_mlx5_mcast_collective_progress(ucc_coll_task_t *coll_task)
     }
 }
 
-ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
-                                                     ucc_base_team_t *team)
+static inline ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
+                                                                   ucc_base_team_t *team)
 {
     ucc_tl_mlx5_team_t            *mlx5_team = ucc_derived_of(team, ucc_tl_mlx5_team_t);
     ucc_tl_mlx5_mcast_coll_comm_t *comm      = mlx5_team->mcast->mcast_comm;
@@ -300,6 +302,33 @@ ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_
     return UCC_ERR_NO_RESOURCE;
 }
 
+ucc_status_t ucc_tl_mlx5_mcast_check_support(ucc_base_coll_args_t *coll_args,
+                                             ucc_base_team_t *team)
+{
+    ucc_coll_args_t *args     = &coll_args->args;
+    int              buf_size = ucc_dt_size(args->src.info.datatype) * args->src.info.count;
+
+    if (UCC_COLL_ARGS_ACTIVE_SET(args)) {
+        tl_trace(team->context->lib, "mcast bcast not supported for active sets");
+        return UCC_ERR_NOT_SUPPORTED;
+    }
+
+    if (UCC_OK != ucc_tl_mlx5_mcast_check_memory_type_cap(coll_args, team)) {
+        tl_trace(team->context->lib, "mcast bcast not compatible with this memory type");
+        return UCC_ERR_NOT_SUPPORTED;
+    }
+
+    if (args->src.info.mem_type == UCC_MEMORY_TYPE_CUDA &&
+        buf_size > 4000) {
+        /* for large messages (more than one mtu) we need zero-copy design which
+         * is not implemented yet */
+        tl_trace(team->context->lib, "mcast cuda bcast not supported for large messages");
+        return UCC_ERR_NOT_IMPLEMENTED;
+    }
+
+    return UCC_OK;
+}
+
 ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task)
 {
     task->super.post = ucc_tl_mlx5_mcast_bcast_start;
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
index a5725915f7..f34e8827f4 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
@@ -14,6 +14,6 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task);
 
 ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req);
 
-ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
-                                                     ucc_base_team_t *team);
+ucc_status_t ucc_tl_mlx5_mcast_check_support(ucc_base_coll_args_t *coll_args,
+                                             ucc_base_team_t *team);
 #endif
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
index f506137f3d..41b6ca14f9 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
@@ -30,6 +30,7 @@ static ucc_status_t ucc_tl_mlx5_mcast_reliability_send_completion(ucc_tl_mlx5_mc
         comm->nack_requests--;
         status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[pkt_id], sizeof(struct packet),
                                                 comm->p2p_pkt[pkt_id].from,
+                                                UCC_MEMORY_TYPE_HOST,
                                                 comm->p2p_ctx, GET_COMPL_OBJ(comm,
                                                 ucc_tl_mlx5_mcast_recv_completion, pkt_id, NULL));
         if (status < 0) {
@@ -48,6 +49,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_resend_packet_reliable(ucc_tl_mlx5_
     uint32_t          psn = comm->p2p_pkt[p2p_pkt_id].psn;
     struct pp_packet *pp  = comm->r_window[psn % comm->wsize];
     ucc_status_t      status;
+    ucc_memory_type_t mem_type;
 
     ucc_assert(pp->psn == psn);
     ucc_assert(comm->p2p_pkt[p2p_pkt_id].type == MCAST_P2P_NEED_NACK_SEND);
@@ -58,8 +60,14 @@
              comm->comm_id, comm->rank, comm->p2p_pkt[p2p_pkt_id].from, psn,
              pp->context, comm->nack_requests);
 
+    if (comm->cuda_mem_enabled) {
+        mem_type = UCC_MEMORY_TYPE_CUDA;
+    } else {
+        mem_type = UCC_MEMORY_TYPE_HOST;
+    }
+
     status = comm->params.p2p_iface.send_nb((void*) (pp->context ? pp->context : pp->buf),
-                                            pp->length, comm->p2p_pkt[p2p_pkt_id].from,
+                                            pp->length, comm->p2p_pkt[p2p_pkt_id].from, mem_type,
                                             comm->p2p_ctx, GET_COMPL_OBJ(comm,
                                             ucc_tl_mlx5_mcast_reliability_send_completion, NULL, p2p_pkt_id));
     if (status < 0) {
@@ -138,11 +146,25 @@ static ucc_status_t ucc_tl_mlx5_mcast_recv_data_completion(ucc_tl_mlx5_mcast_p2p
     struct pp_packet             *pp  = (struct pp_packet *)obj->data[1];
     ucc_tl_mlx5_mcast_coll_req_t *req = (ucc_tl_mlx5_mcast_coll_req_t *)obj->data[2];
     void                         *dest;
+    ucc_memory_type_t             mem_type;
 
     tl_trace(comm->lib, "[comm %d, rank %d] Recved data psn %d", comm->comm_id, comm->rank, pp->psn);
     dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm);
-    memcpy(dest, (void*) pp->buf, pp->length);
+
+    if (comm->cuda_mem_enabled) {
+        mem_type = UCC_MEMORY_TYPE_CUDA;
+    } else {
+        mem_type = UCC_MEMORY_TYPE_HOST;
+    }
+
+    status = ucc_mc_memcpy(dest, (void*) pp->buf, pp->length,
+                           mem_type, mem_type);
+    if (ucc_unlikely(status != UCC_OK)) {
+        tl_error(comm->lib, "failed to copy buffer");
+        return status;
+    }
+
     req->to_recv--;
     comm->r_window[pp->psn % comm->wsize] = pp;
@@ -165,6 +187,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable_send_NACK(ucc_tl_mlx5_mcas
     struct pp_packet *pp;
     ucc_rank_t        parent;
     struct packet    *p;
+    ucc_memory_type_t mem_type;
 
     p = ucc_calloc(1, sizeof(struct packet));
     p->type = MCAST_P2P_NACK;
@@ -176,7 +199,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable_send_NACK(ucc_tl_mlx5_mcas
 
     comm->nacks_counter++;
 
-    status = comm->params.p2p_iface.send_nb(p, sizeof(struct packet), parent,
+    status = comm->params.p2p_iface.send_nb(p, sizeof(struct packet), parent, UCC_MEMORY_TYPE_HOST,
                                             comm->p2p_ctx, GET_COMPL_OBJ(comm,
                                             ucc_tl_mlx5_mcast_reliability_send_completion, p, UINT_MAX));
     if (status < 0) {
@@ -193,8 +216,14 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable_send_NACK(ucc_tl_mlx5_mcas
 
     comm->recv_drop_packet_in_progress = true;
 
+    if (comm->cuda_mem_enabled) {
+        mem_type = UCC_MEMORY_TYPE_CUDA;
+    } else {
+        mem_type = UCC_MEMORY_TYPE_HOST;
+    }
+
     status = comm->params.p2p_iface.recv_nb((void*) pp->buf,
-                                            pp->length, parent,
+                                            pp->length, parent, mem_type,
                                             comm->p2p_ctx, GET_COMPL_OBJ(comm,
                                             ucc_tl_mlx5_mcast_recv_data_completion, pp, req));
     if (status < 0) {
@@ -225,7 +254,7 @@ ucc_status_t ucc_tl_mlx5_mcast_reliable_send(ucc_tl_mlx5_mcast_coll_comm_t *comm
                  comm->rank, parent, comm->parent_n, comm->psn);
 
         status = comm->params.p2p_iface.send_nb(&comm->p2p_spkt[i],
-                                                sizeof(struct packet), parent,
+                                                sizeof(struct packet), parent, UCC_MEMORY_TYPE_HOST,
                                                 comm->p2p_ctx, GET_COMPL_OBJ(comm,
                                                 ucc_tl_mlx5_mcast_send_completion, i, NULL));
         if (status < 0) {
@@ -325,7 +354,7 @@ ucc_status_t ucc_tl_mlx5_mcast_prepare_reliable(ucc_tl_mlx5_mcast_coll_comm_t *c
                      comm->rank, child, comm->child_n, comm->psn);
 
             status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[comm->child_n - 1],
-                                                    sizeof(struct packet), child,
+                                                    sizeof(struct packet), child, UCC_MEMORY_TYPE_HOST,
                                                     comm->p2p_ctx, GET_COMPL_OBJ(comm,
                                                     ucc_tl_mlx5_mcast_recv_completion, comm->child_n - 1, req));
             if (status < 0) {
@@ -369,8 +398,8 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com
                                               ucc_tl_mlx5_mcast_coll_req_t *req,
                                               struct pp_packet* pp)
 {
-    ucc_status_t status = UCC_OK;
-    void *dest;
+    ucc_status_t      status = UCC_OK;
+    void             *dest;
+    ucc_memory_type_t mem_type;
 
     ucc_assert(pp->psn >= req->start_psn && pp->psn < req->start_psn + req->num_packets);
diff --git a/src/components/tl/mlx5/tl_mlx5_coll.c b/src/components/tl/mlx5/tl_mlx5_coll.c
index a8add9715e..909b457325 100644
--- a/src/components/tl/mlx5/tl_mlx5_coll.c
+++ b/src/components/tl/mlx5/tl_mlx5_coll.c
@@ -14,19 +14,13 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args,
 {
     ucc_status_t        status = UCC_OK;
     ucc_tl_mlx5_task_t *task   = NULL;
-
-    if (UCC_COLL_ARGS_ACTIVE_SET(&coll_args->args)) {
-        tl_trace(team->context->lib, "mcast bcast not supported for active sets");
-        return UCC_ERR_NOT_SUPPORTED;
-    }
 
-    if (UCC_OK != ucc_tl_mlx5_mcast_check_memory_type_cap(coll_args, team)) {
-        tl_trace(team->context->lib, "mcast bcast not compatible with this memory type");
-        return UCC_ERR_NOT_SUPPORTED;
+    status = ucc_tl_mlx5_mcast_check_support(coll_args, team);
+    if (UCC_OK != status) {
+        return status;
     }
 
     task = ucc_tl_mlx5_get_task(coll_args, team);
-
     if (ucc_unlikely(!task)) {
         return UCC_ERR_NO_MEMORY;
     }
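Note on the memory-type plumbing above (a reviewer sketch, not part of the patch): the data-path call sites in tl_mlx5_mcast_progress.c each repeat the same cuda_mem_enabled check to pick between UCC_MEMORY_TYPE_CUDA and UCC_MEMORY_TYPE_HOST, while control packets always pass UCC_MEMORY_TYPE_HOST. Assuming only the fields and types already used in the hunks above, that mapping could be factored into a small helper; the function name below is hypothetical.

/* Hypothetical helper (not in this patch): maps the communicator's
 * cuda_mem_enabled flag to the memory type of its payload buffers.
 * Control packets (struct packet) would keep passing UCC_MEMORY_TYPE_HOST. */
static inline ucc_memory_type_t
ucc_tl_mlx5_mcast_buf_mem_type(ucc_tl_mlx5_mcast_coll_comm_t *comm)
{
    return comm->cuda_mem_enabled ? UCC_MEMORY_TYPE_CUDA :
                                    UCC_MEMORY_TYPE_HOST;
}

With such a helper, the if/else blocks in ucc_tl_mlx5_mcast_resend_packet_reliable, ucc_tl_mlx5_mcast_recv_data_completion, and ucc_tl_mlx5_mcast_reliable_send_NACK would reduce to passing its return value as the new mem_type argument of send_nb/recv_nb and to ucc_mc_memcpy.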