From f01fd27e0452052be76454d995c1b60497a4782f Mon Sep 17 00:00:00 2001 From: Mamzi Bayatpour <77160721+MamziB@users.noreply.github.com> Date: Mon, 11 Mar 2024 07:32:04 -0700 Subject: [PATCH] TL/MLX5: adding mcast bcast algo (#929) --- src/components/tl/mlx5/Makefile.am | 1 + .../tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c | 2 +- .../tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h | 2 +- src/components/tl/mlx5/mcast/tl_mlx5_mcast.h | 18 +- .../tl/mlx5/mcast/tl_mlx5_mcast_coll.c | 257 +++++++++++- .../tl/mlx5/mcast/tl_mlx5_mcast_coll.h | 2 +- .../tl/mlx5/mcast/tl_mlx5_mcast_context.c | 2 +- .../tl/mlx5/mcast/tl_mlx5_mcast_helper.c | 18 +- .../tl/mlx5/mcast/tl_mlx5_mcast_helper.h | 4 +- .../tl/mlx5/mcast/tl_mlx5_mcast_progress.c | 378 ++++++++++++++++++ .../tl/mlx5/mcast/tl_mlx5_mcast_progress.h | 16 +- .../tl/mlx5/mcast/tl_mlx5_mcast_rcache.c | 9 +- .../tl/mlx5/mcast/tl_mlx5_mcast_rcache.h | 6 +- .../tl/mlx5/mcast/tl_mlx5_mcast_team.c | 12 +- src/components/tl/mlx5/tl_mlx5_coll.c | 9 +- src/components/tl/mlx5/tl_mlx5_coll.h | 1 + 16 files changed, 673 insertions(+), 64 deletions(-) create mode 100644 src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c diff --git a/src/components/tl/mlx5/Makefile.am b/src/components/tl/mlx5/Makefile.am index 4236b0b9cd..01b94cfa1d 100644 --- a/src/components/tl/mlx5/Makefile.am +++ b/src/components/tl/mlx5/Makefile.am @@ -22,6 +22,7 @@ mcast = \ mcast/p2p/ucc_tl_mlx5_mcast_p2p.h \ mcast/p2p/ucc_tl_mlx5_mcast_p2p.c \ mcast/tl_mlx5_mcast_progress.h \ + mcast/tl_mlx5_mcast_progress.c \ mcast/tl_mlx5_mcast_helper.h \ mcast/tl_mlx5_mcast_helper.c \ mcast/tl_mlx5_mcast_team.c diff --git a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c index c8dca25fc3..d5c5d9dfb4 100644 --- a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c +++ b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ diff --git a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h index 6e19e59dde..e82f7546a7 100644 --- a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h +++ b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h index 66da5ff474..1d08f99edf 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ @@ -64,14 +64,14 @@ typedef struct ucc_tl_mlx5_mcast_p2p_completion_obj { typedef int (*ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t)(void *wait_arg); -typedef int (*ucc_tl_mlx5_mcast_p2p_send_nb_fn_t)(void* src, size_t size, - ucc_rank_t rank, void *context, - ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj); +typedef ucc_status_t (*ucc_tl_mlx5_mcast_p2p_send_nb_fn_t)(void* src, size_t size, + ucc_rank_t rank, void *context, + ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj); -typedef int (*ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t)(void* src, size_t size, - ucc_rank_t rank, void *context, - ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj); +typedef ucc_status_t (*ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t)(void* src, size_t size, + ucc_rank_t rank, void *context, + ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj); typedef struct ucc_tl_mlx5_mcast_p2p_interface { ucc_tl_mlx5_mcast_p2p_send_nb_fn_t send_nb; @@ -228,8 +228,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm { void *p2p_ctx; ucc_base_lib_t *lib; struct sockaddr_in6 mcast_addr; - int parents[MAX_COMM_POW2]; - int children[MAX_COMM_POW2]; + ucc_rank_t parents[MAX_COMM_POW2]; + ucc_rank_t children[MAX_COMM_POW2]; int nack_requests; int nacks_counter; int n_prep_reliable; diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c index 1cd2f56512..4669c88640 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c @@ -1,22 +1,248 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ #include "tl_mlx5_coll.h" #include "tl_mlx5_mcast_helper.h" +#include "tl_mlx5_mcast_rcache.h" -ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* req /* NOLINT */) +static inline ucc_status_t ucc_tl_mlx5_mcast_r_window_recycle(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req) { - return UCC_ERR_NOT_SUPPORTED; + ucc_status_t status = UCC_OK; + int wsize = comm->wsize; + int num_free_win = wsize - (comm->psn - comm->last_acked); + int req_completed = (req->to_send == 0 && req->to_recv == 0); + struct pp_packet *pp = NULL; + + ucc_assert(comm->recv_drop_packet_in_progress == false); + ucc_assert(req->to_send >= 0); + + /* When do we need to perform reliability protocol: + 1. Always in the end of the window + 2. 
For the zcopy case: in the end of collective, because we can't signal completion + before made sure that children received the data - user can modify buffer */ + + ucc_assert(num_free_win >= 0); + + if (!num_free_win || (req->proto == MCAST_PROTO_ZCOPY && req_completed)) { + status = ucc_tl_mlx5_mcast_reliable(comm); + if (UCC_OK != status) { + return status; + } + + comm->n_mcast_reliable++; + + for (;comm->last_acked < comm->psn; comm->last_acked++) { + pp = comm->r_window[comm->last_acked & (wsize-1)]; + ucc_assert(pp != &comm->dummy_packet); + comm->r_window[comm->last_acked & (wsize-1)] = &comm->dummy_packet; + + pp->context = 0; + ucc_list_add_tail(&comm->bpool, &pp->super); + } + + if (!req_completed) { + status = ucc_tl_mlx5_mcast_prepare_reliable(comm, req, req->root); + if (ucc_unlikely(UCC_OK != status)) { + return status; + } + } + } + + return UCC_OK; +} + +static inline ucc_status_t ucc_tl_mlx5_mcast_do_bcast(ucc_tl_mlx5_mcast_coll_req_t *req) +{ + ucc_status_t status = UCC_OK; + ucc_tl_mlx5_mcast_coll_comm_t *comm = req->comm; + int zcopy = req->proto != MCAST_PROTO_EAGER; + int wsize = comm->wsize; + int num_free_win; + int num_sent; + int to_send; + int to_recv; + int to_recv_left; + int pending_q_size; + + if (ucc_unlikely(comm->recv_drop_packet_in_progress)) { + /* wait till parent resend the dropped packet */ + return UCC_INPROGRESS; + } + + if (comm->reliable_in_progress) { + /* wait till all the children send their ACK for current window */ + status = ucc_tl_mlx5_mcast_r_window_recycle(comm, req); + if (UCC_OK != status) { + return status; + } + } + + if (req->to_send || req->to_recv) { + num_free_win = wsize - (comm->psn - comm->last_acked); + + /* Send data if i'm root and there is a space in the window */ + if (num_free_win && req->am_root) { + num_sent = req->num_packets - req->to_send; + ucc_assert(req->to_send > 0); + ucc_assert(req->first_send_psn + num_sent < comm->last_acked + wsize); + if (req->first_send_psn + num_sent < comm->last_acked + wsize && + req->to_send) { + /* How many to send: either all that are left (if they fit into window) or + up to the window limit */ + to_send = ucc_min(req->to_send, + comm->last_acked + wsize - (req->first_send_psn + num_sent)); + ucc_tl_mlx5_mcast_send(comm, req, to_send, zcopy); + + num_free_win = wsize - (comm->psn - comm->last_acked); + } + } + + status = ucc_tl_mlx5_mcast_prepare_reliable(comm, req, req->root); + if (ucc_unlikely(UCC_OK != status)) { + return status; + } + + if (num_free_win && req->to_recv) { + /* How many to recv: either all that are left or up to the window limit. */ + pending_q_size = 0; + to_recv = ucc_min(num_free_win, req->to_recv); + to_recv_left = ucc_tl_mlx5_mcast_recv(comm, req, to_recv, &pending_q_size); + + if (to_recv == to_recv_left) { + /* We didn't receive anything: increase the stalled counter and get ready for + drop event */ + if (comm->stalled++ >= DROP_THRESHOLD) { + + tl_trace(comm->lib, "Did not receive the packet with psn in" + " current window range, so get ready for drop" + " event. 
pending_q_size %d current comm psn %d" + " last_acked psn %d stall threshold %d ", + pending_q_size, comm->psn, comm->last_acked, + DROP_THRESHOLD); + + status = ucc_tl_mlx5_mcast_bcast_check_drop(comm, req); + if (UCC_INPROGRESS == status) { + return status; + } + } + } else if (to_recv_left < 0) { + /* a failure happend during cq polling */ + return UCC_ERR_NO_MESSAGE; + } else { + comm->stalled = 0; + comm->timer = 0; + } + } + + /* This function will check if we have to do a round of reliability protocol */ + status = ucc_tl_mlx5_mcast_r_window_recycle(comm, req); + if (UCC_OK != status) { + return status; + } + } + + if (req->to_send || req->to_recv || (zcopy && comm->psn != comm->last_acked)) { + return UCC_INPROGRESS; + } else { + return status; + } +} + + +ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* req) +{ + ucc_status_t status = UCC_OK; + + ucc_assert(req->comm->psn >= req->start_psn); + + status = ucc_tl_mlx5_mcast_do_bcast(req); + if (UCC_INPROGRESS != status) { + ucc_assert(req->comm->ctx != NULL); + ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->rreg); + req->rreg = NULL; + } + + return status; +} + +static inline ucc_status_t ucc_tl_mlx5_mcast_prepare_bcast(void* buf, size_t size, ucc_rank_t root, + ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req) +{ + ucc_status_t status; + ucc_tl_mlx5_mcast_reg_t *reg; + + req->comm = comm; + req->ptr = buf; + req->length = size; + req->root = root; + req->am_root = (root == comm->rank); + req->mr = comm->pp_mr; + req->rreg = NULL; + req->proto = (req->length < comm->max_eager) ? MCAST_PROTO_EAGER : MCAST_PROTO_ZCOPY; + + status = ucc_tl_mlx5_mcast_prepare_reliable(comm, req, req->root); + if (ucc_unlikely(UCC_OK != status)) { + return status; + } + + if (req->am_root) { + if (req->proto != MCAST_PROTO_EAGER) { + status = ucc_tl_mlx5_mcast_mem_register(comm->ctx, req->ptr, req->length, ®); + if (UCC_OK != status) { + return status; + } + req->rreg = reg; + req->mr = reg->mr; + } + } + + req->offset = 0; + req->start_psn = comm->last_psn; + req->num_packets = ucc_max(ucc_div_round_up(req->length, comm->max_per_packet), 1); + req->last_pkt_len = req->length - (req->num_packets - 1)*comm->max_per_packet; + + ucc_assert(req->last_pkt_len > 0 && req->last_pkt_len <= comm->max_per_packet); + + comm->last_psn += req->num_packets; + req->first_send_psn = req->start_psn; + req->to_send = req->am_root ? req->num_packets : 0; + req->to_recv = req->am_root ? 
0 : req->num_packets; + + return UCC_OK; } -ucc_status_t mcast_coll_do_bcast(void* buf, size_t size, ucc_rank_t root, void *mr, /* NOLINT */ - ucc_tl_mlx5_mcast_coll_comm_t *comm, /* NOLINT */ - ucc_tl_mlx5_mcast_coll_req_t **task_req_handle /* NOLINT */) +ucc_status_t ucc_tl_mlx5_mcast_coll_do_bcast(void* buf, size_t size, ucc_rank_t root, + ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t **task_req_handle) { - return UCC_ERR_NOT_SUPPORTED; + ucc_status_t status; + ucc_tl_mlx5_mcast_coll_req_t *req; + + tl_trace(comm->lib, "MCAST bcast start, buf %p, size %ld, root %d, comm %d, " + "comm_size %d, am_i_root %d comm->psn = %d \n", + buf, size, root, comm->comm_id, comm->commsize, comm->rank == + root, comm->psn ); + + req = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_coll_req_t), "mcast_req"); + if (!req) { + return UCC_ERR_NO_MEMORY; + } + + status = ucc_tl_mlx5_mcast_prepare_bcast(buf, size, root, comm, req); + if (UCC_OK != status) { + ucc_free(req); + return status; + } + + status = UCC_INPROGRESS; + *task_req_handle = req; + + return status; } ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task) @@ -35,8 +261,8 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task) task->bcast_mcast.req_handle = NULL; - status = mcast_coll_do_bcast(buf, data_size, root, NULL, comm, - &task->bcast_mcast.req_handle); + status = ucc_tl_mlx5_mcast_coll_do_bcast(buf, data_size, root, comm, + &task->bcast_mcast.req_handle); if (status < 0) { tl_error(UCC_TASK_LIB(task), "mcast_coll_do_bcast failed:%d", status); coll_task->status = status; @@ -50,23 +276,16 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task) void ucc_tl_mlx5_mcast_collective_progress(ucc_coll_task_t *coll_task) { - ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t); - ucc_status_t status = UCC_OK; - ucc_tl_mlx5_mcast_coll_req_t *req = task->bcast_mcast.req_handle; + ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t); + ucc_tl_mlx5_mcast_coll_req_t *req = task->bcast_mcast.req_handle; if (req != NULL) { - status = ucc_tl_mlx5_mcast_test(req); - if (UCC_OK == status) { - coll_task->status = UCC_OK; - ucc_free(req); - task->bcast_mcast.req_handle = NULL; - } + coll_task->status = ucc_tl_mlx5_mcast_test(req); } } ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task) { - task->super.post = ucc_tl_mlx5_mcast_bcast_start; task->super.progress = ucc_tl_mlx5_mcast_collective_progress; diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h index 47ddc301aa..74385b1573 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c index 5361f1deb5..192000ee86 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c index de3af55a60..81d142b3a1 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -417,15 +417,21 @@ ucc_status_t ucc_tl_mlx5_fini_mcast_group(ucc_tl_mlx5_mcast_coll_context_t *ctx, ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm) { - int ret; - ucc_status_t status; + ucc_tl_mlx5_mcast_context_t *mcast_ctx = ucc_container_of(comm->ctx, ucc_tl_mlx5_mcast_context_t, mcast_context); + ucc_tl_mlx5_context_t *mlx5_ctx = ucc_container_of(mcast_ctx, ucc_tl_mlx5_context_t, mcast); + ucc_context_h context = mlx5_ctx->super.super.ucc_context; + int ret; + ucc_status_t status; tl_debug(comm->lib, "cleaning mcast comm: %p, id %d, mlid %x", comm, comm->comm_id, comm->mcast_lid); - if (UCC_OK != (status = ucc_tl_mlx5_mcast_reliable(comm))) { - // TODO handle (UCC_INPROGRESS == ret) - tl_error(comm->lib, "couldn't clean mcast team: relibality progress status %d", + while (UCC_INPROGRESS == (status = ucc_tl_mlx5_mcast_reliable(comm))) { + ucc_context_progress(context); + } + + if (UCC_OK != status) { + tl_error(comm->lib, "failed to clean mcast team: relibality progress status %d", status); return status; } diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h index bd3e7521fb..427039316d 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -319,7 +319,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable(ucc_tl_mlx5_mcast_coll_com } } - status = ucc_tl_mlx5_mcast_check_nack_requests_all(comm); + status = ucc_tl_mlx5_mcast_check_nack_requests(comm, UINT32_MAX); if (UCC_OK != status) { return status; } diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c new file mode 100644 index 0000000000..a201944ecf --- /dev/null +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c @@ -0,0 +1,378 @@ +/** + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. 
+ */
+
+#include "tl_mlx5_mcast_progress.h"
+
+static ucc_status_t ucc_tl_mlx5_mcast_recv_completion(ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj);
+
+static ucc_status_t ucc_tl_mlx5_mcast_send_completion(ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj);
+
+static ucc_status_t ucc_tl_mlx5_mcast_dummy_completion(ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj) // NOLINT
+{
+    return UCC_OK;
+}
+
+static ucc_tl_mlx5_mcast_p2p_completion_obj_t dummy_completion_obj = {
+    .compl_cb = ucc_tl_mlx5_mcast_dummy_completion,
+};
+
+static inline ucc_status_t ucc_tl_mlx5_mcast_resend_packet_reliable(ucc_tl_mlx5_mcast_coll_comm_t *comm,
+                                                                    int p2p_pkt_id)
+{
+    uint32_t          psn = comm->p2p_pkt[p2p_pkt_id].psn;
+    struct pp_packet *pp  = comm->r_window[psn % comm->wsize];
+    ucc_status_t      status;
+
+    ucc_assert(pp->psn == psn);
+
+    tl_trace(comm->lib, "[comm %d, rank %d] Send data NACK: to %d, psn %d, context %ld\n",
+             comm->comm_id, comm->rank,
+             comm->p2p_pkt[p2p_pkt_id].from, psn, pp->context);
+
+    status = comm->params.p2p_iface.send_nb((void*) (pp->context ? pp->context : pp->buf),
+                                            pp->length, comm->p2p_pkt[p2p_pkt_id].from,
+                                            comm->p2p_ctx, &dummy_completion_obj);
+    if (status < 0) {
+        return status;
+    }
+
+    status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[p2p_pkt_id],
+                                            sizeof(struct packet), comm->p2p_pkt[p2p_pkt_id].from,
+                                            comm->p2p_ctx, GET_COMPL_OBJ(comm,
+                                            ucc_tl_mlx5_mcast_recv_completion, p2p_pkt_id, NULL));
+    if (status < 0) {
+        return status;
+    }
+
+    return UCC_OK;
+}
+
+ucc_status_t ucc_tl_mlx5_mcast_check_nack_requests(ucc_tl_mlx5_mcast_coll_comm_t *comm, uint32_t psn)
+{
+    ucc_status_t      status = UCC_OK;
+    int               i;
+    struct pp_packet *pp;
+
+    if (!comm->nack_requests) {
+        return UCC_OK;
+    }
+
+    if (psn != UINT32_MAX) {
+        for (i = 0; i < comm->child_n; i++) {
+            if (psn == comm->p2p_pkt[i].psn &&
+                comm->p2p_pkt[i].type == MCAST_P2P_NEED_NACK_SEND) {
+                status = ucc_tl_mlx5_mcast_resend_packet_reliable(comm, i);
+                if (status != UCC_OK) {
+                    break;
+                }
+                comm->p2p_pkt[i].type = MCAST_P2P_ACK;
+                comm->nack_requests--;
+            }
+        }
+    } else {
+        for (i = 0; i < comm->child_n; i++) {
+            if (comm->p2p_pkt[i].type == MCAST_P2P_NEED_NACK_SEND) {
+                psn = comm->p2p_pkt[i].psn;
+                pp  = comm->r_window[psn % comm->wsize];
+                if (psn == pp->psn) {
+                    status = ucc_tl_mlx5_mcast_resend_packet_reliable(comm, i);
+                    if (status < 0) {
+                        break;
+                    }
+                    comm->p2p_pkt[i].type = MCAST_P2P_ACK;
+                    comm->nack_requests--;
+                }
+            }
+        }
+    }
+
+    return status;
+}
+
+static inline int ucc_tl_mlx5_mcast_find_nack_psn(ucc_tl_mlx5_mcast_coll_comm_t* comm,
+                                                  ucc_tl_mlx5_mcast_coll_req_t *req)
+{
+    int psn            = ucc_max(comm->last_acked, req->start_psn);
+    int max_search_psn = ucc_min(req->start_psn + req->num_packets,
+                                 comm->last_acked + comm->wsize + 1);
+
+    for (; psn < max_search_psn; psn++) {
+        if (!PSN_RECEIVED(psn, comm)) {
+            break;
+        }
+    }
+
+    ucc_assert(psn < max_search_psn);
+
+    return psn;
+}
+
+static inline ucc_rank_t ucc_tl_mlx5_mcast_get_nack_parent(ucc_tl_mlx5_mcast_coll_req_t *req)
+{
+    return req->parent;
+}
+
+/* When the parent resends a lost packet to a child, this function is called on the child side */
+static ucc_status_t ucc_tl_mlx5_mcast_recv_data_completion(ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj)
+{
+    ucc_status_t                   status = UCC_OK;
+    ucc_tl_mlx5_mcast_coll_comm_t *comm   = (ucc_tl_mlx5_mcast_coll_comm_t *)obj->data[0];
+    struct pp_packet              *pp     = (struct pp_packet *)obj->data[1];
+    ucc_tl_mlx5_mcast_coll_req_t  *req    = (ucc_tl_mlx5_mcast_coll_req_t *)obj->data[2];
+    void                          *dest;
+
+    tl_trace(comm->lib, "[comm %d, rank %d] Received data psn %d", comm->comm_id,
+             comm->rank, pp->psn);
+
+    dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm);
+    memcpy(dest, (void*) pp->buf, pp->length);
+    req->to_recv--;
+    comm->r_window[pp->psn % comm->wsize] = pp;
+
+    status = ucc_tl_mlx5_mcast_check_nack_requests(comm, pp->psn);
+    if (status < 0) {
+        return status;
+    }
+
+    comm->psn++;
+    comm->recv_drop_packet_in_progress = false;
+
+    return status;
+}
+
+static inline ucc_status_t ucc_tl_mlx5_mcast_reliable_send_NACK(ucc_tl_mlx5_mcast_coll_comm_t* comm,
+                                                                ucc_tl_mlx5_mcast_coll_req_t *req)
+{
+    struct pp_packet *pp;
+    ucc_rank_t        parent;
+    ucc_status_t      status;
+
+    struct packet p = {
+        .type    = MCAST_P2P_NACK,
+        .psn     = ucc_tl_mlx5_mcast_find_nack_psn(comm, req),
+        .from    = comm->rank,
+        .comm_id = comm->comm_id,
+    };
+
+    parent = ucc_tl_mlx5_mcast_get_nack_parent(req);
+
+    comm->nacks_counter++;
+
+    status = comm->params.p2p_iface.send_nb(&p, sizeof(struct packet), parent,
+                                            comm->p2p_ctx, &dummy_completion_obj);
+    if (status < 0) {
+        return status;
+    }
+
+    tl_trace(comm->lib, "[comm %d, rank %d] Sent NAK : parent %d, psn %d",
+             comm->comm_id, comm->rank, parent, p.psn);
+
+    // Prepare to obtain the data.
+    pp         = ucc_tl_mlx5_mcast_buf_get_free(comm);
+    pp->psn    = p.psn;
+    pp->length = PSN_TO_RECV_LEN(pp->psn, req, comm);
+
+    comm->recv_drop_packet_in_progress = true;
+
+    status = comm->params.p2p_iface.recv_nb((void*) pp->buf,
+                                            pp->length, parent,
+                                            comm->p2p_ctx, GET_COMPL_OBJ(comm,
+                                            ucc_tl_mlx5_mcast_recv_data_completion, pp, req));
+    if (status < 0) {
+        return status;
+    }
+
+    return UCC_INPROGRESS;
+}
+
+ucc_status_t ucc_tl_mlx5_mcast_reliable_send(ucc_tl_mlx5_mcast_coll_comm_t *comm)
+{
+    ucc_rank_t   i;
+    ucc_rank_t   parent;
+    ucc_status_t status;
+
+    tl_trace(comm->lib, "comm %p, psn %d, last_acked %d, n_parent %d",
+             comm, comm->psn, comm->last_acked, comm->parent_n);
+
+    ucc_assert(!comm->reliable_in_progress);
+
+    for (i = 0; i < comm->parent_n; i++) {
+        parent                    = comm->parents[i];
+        comm->p2p_spkt[i].type    = MCAST_P2P_ACK;
+        comm->p2p_spkt[i].psn     = comm->last_acked + comm->wsize;
+        comm->p2p_spkt[i].comm_id = comm->comm_id;
+
+        tl_trace(comm->lib, "rank %d, Posting SEND to parent %d, n_parent %d, psn %d",
+                 comm->rank, parent, comm->parent_n, comm->psn);
+
+        status = comm->params.p2p_iface.send_nb(&comm->p2p_spkt[i],
+                                                sizeof(struct packet), parent,
+                                                comm->p2p_ctx, GET_COMPL_OBJ(comm,
+                                                ucc_tl_mlx5_mcast_send_completion, i, NULL));
+        if (status < 0) {
+            return status;
+        }
+    }
+
+    return UCC_OK;
+}
+
+static ucc_status_t ucc_tl_mlx5_mcast_recv_completion(ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj)
+{
+    ucc_tl_mlx5_mcast_coll_comm_t *comm   = (ucc_tl_mlx5_mcast_coll_comm_t*)obj->data[0];
+    int                            pkt_id = (int)obj->data[1];
+    uint32_t                       psn;
+    struct pp_packet              *pp;
+    ucc_status_t                   status;
+
+    ucc_assert(comm->comm_id == comm->p2p_pkt[pkt_id].comm_id);
+
+    if (comm->p2p_pkt[pkt_id].type != MCAST_P2P_ACK) {
+        ucc_assert(comm->p2p_pkt[pkt_id].type == MCAST_P2P_NACK);
+        psn = comm->p2p_pkt[pkt_id].psn;
+        pp  = comm->r_window[psn % comm->wsize];
+
+        tl_trace(comm->lib, "[comm %d, rank %d] Got NACK: from %d, psn %d, avail %d",
+                 comm->comm_id, comm->rank,
+                 comm->p2p_pkt[pkt_id].from, psn, pp->psn == psn);
+
+        if (pp->psn == psn) {
+            status = ucc_tl_mlx5_mcast_resend_packet_reliable(comm, pkt_id);
+            if (status < 0) {
+                return status;
+            }
+        } else {
+            comm->p2p_pkt[pkt_id].type = MCAST_P2P_NEED_NACK_SEND;
+            comm->nack_requests++;
+        }
+
+    } else {
+        comm->racks_n++;
+    }
+
+    ucc_mpool_put(obj); /* return the completion object back to the mem pool compl_objects_mp */
+
+    return UCC_OK;
+} + +static ucc_status_t ucc_tl_mlx5_mcast_send_completion(ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj) +{ + ucc_tl_mlx5_mcast_coll_comm_t *comm = (ucc_tl_mlx5_mcast_coll_comm_t*)obj->data[0]; + + comm->sacks_n++; + ucc_mpool_put(obj); + return UCC_OK; +} + +static inline int add_uniq(ucc_rank_t *arr, uint32_t *len, ucc_rank_t value) +{ + int i; + + for (i=0; i<(*len); i++) { + if (arr[i] == value) { + return 0; + } + } + + arr[*len] = value; + (*len)++; + return 1; +} + +ucc_status_t ucc_tl_mlx5_mcast_prepare_reliable(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req, + ucc_rank_t root) +{ + ucc_rank_t mask = 1; + ucc_rank_t vrank = TO_VIRTUAL(comm->rank, comm->commsize, root); + ucc_rank_t child; + ucc_status_t status; + + ucc_assert(comm->commsize <= pow(2, MAX_COMM_POW2)); + + while (mask < comm->commsize) { + if (vrank & mask) { + req->parent = TO_ORIGINAL((vrank ^ mask), comm->commsize, root); + add_uniq(comm->parents, &comm->parent_n, req->parent); + break; + } else { + child = vrank ^ mask; + if (child < comm->commsize) { + child = TO_ORIGINAL(child, comm->commsize, root); + if (add_uniq(comm->children, &comm->child_n, child)) { + tl_trace(comm->lib, "rank %d, Posting RECV from child %d, n_child %d, psn %d", + comm->rank, child, comm->child_n, comm->psn); + + status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[comm->child_n - 1], + sizeof(struct packet), child, + comm->p2p_ctx, GET_COMPL_OBJ(comm, + ucc_tl_mlx5_mcast_recv_completion, comm->child_n - 1, req)); + if (status < 0) { + return status; + } + } + } + } + + mask <<= 1; + } + + return UCC_OK; +} + +static inline uint64_t ucc_tl_mlx5_mcast_get_timer(void) +{ + double t_second = ucc_get_time(); + return (uint64_t) (t_second * 1000000); +} + +ucc_status_t ucc_tl_mlx5_mcast_bcast_check_drop(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req) +{ + ucc_status_t status = UCC_OK; + + if (comm->timer == 0) { + comm->timer = ucc_tl_mlx5_mcast_get_timer(); + } else { + if (ucc_tl_mlx5_mcast_get_timer() - comm->timer >= comm->ctx->params.timeout) { + tl_trace(comm->lib, "[REL] time out %d", comm->psn); + status = ucc_tl_mlx5_mcast_reliable_send_NACK(comm, req); + comm->timer = 0; + } + } + + return status; +} + +ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req, + struct pp_packet* pp) +{ + ucc_status_t status = UCC_OK; + void *dest; + ucc_assert(pp->psn >= req->start_psn && + pp->psn < req->start_psn + req->num_packets); + + ucc_assert(pp->length == PSN_TO_RECV_LEN(pp->psn, req, comm)); + ucc_assert(pp->context == 0); + + if (pp->length > 0 ) { + dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm); + memcpy(dest, (void*) pp->buf, pp->length); + } + + comm->r_window[pp->psn & (comm->wsize-1)] = pp; + status = ucc_tl_mlx5_mcast_check_nack_requests(comm, pp->psn); + if (status < 0) { + return status; + } + + req->to_recv--; + comm->psn++; + ucc_assert(comm->recv_drop_packet_in_progress == false); + + return status; +} + diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.h index da30a4b1c0..1bceb89976 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ @@ -42,23 +42,23 @@ _req; \ }) -int ucc_tl_mlx5_mcast_prepare_reliable(ucc_tl_mlx5_mcast_coll_comm_t *comm, - ucc_tl_mlx5_mcast_coll_req_t *req, - ucc_rank_t root); +ucc_status_t ucc_tl_mlx5_mcast_prepare_reliable(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req, + ucc_rank_t root); ucc_status_t ucc_tl_mlx5_mcast_bcast_check_drop(ucc_tl_mlx5_mcast_coll_comm_t *comm, ucc_tl_mlx5_mcast_coll_req_t *req); ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *comm, - ucc_tl_mlx5_mcast_coll_req_t *req, - struct pp_packet* pp); + ucc_tl_mlx5_mcast_coll_req_t *req, + struct pp_packet* pp); ucc_status_t ucc_tl_mlx5_mcast_check_nack_requests(ucc_tl_mlx5_mcast_coll_comm_t *comm, - uint32_t psn); + uint32_t psn); ucc_status_t ucc_tl_mlx5_mcast_reliable_send(ucc_tl_mlx5_mcast_coll_comm_t* comm); -ucc_status_t ucc_tl_mlx5_mcast_check_nack_requests_all(ucc_tl_mlx5_mcast_coll_comm_t* comm); +ucc_status_t ucc_tl_mlx5_mcast_check_nack_requests(ucc_tl_mlx5_mcast_coll_comm_t* comm, uint32_t psn); #endif /* ifndef TL_MLX5_MCAST_PROGRESS_H_ */ diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.c index 41caf5693a..47f73e485b 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.c @@ -109,9 +109,8 @@ ucc_tl_mlx5_mcast_mem_register(ucc_tl_mlx5_mcast_coll_context_t *ctx, return UCC_OK; } -ucc_status_t -ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx, - ucc_tl_mlx5_mcast_reg_t *reg) +void ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx, + ucc_tl_mlx5_mcast_reg_t *reg) { ucc_tl_mlx5_mcast_rcache_region_t *region; ucc_rcache_t *rcache; @@ -119,15 +118,13 @@ ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx, rcache = ctx->rcache; if (reg == NULL) { - return UCC_OK; + return; } ucc_assert(rcache != NULL); tl_trace(ctx->lib, "memory deregister mr %p", reg->mr); region = ucc_container_of(reg, ucc_tl_mlx5_mcast_rcache_region_t, reg); ucc_rcache_region_put(rcache, ®ion->super); - - return UCC_OK; } static ucc_rcache_ops_t ucc_tl_mlx5_rcache_ops = { diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.h index e1836704ad..da90f562a1 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_rcache.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -13,5 +13,5 @@ ucc_status_t ucc_tl_mlx5_mcast_mem_register(ucc_tl_mlx5_mcast_coll_context_t *ctx, void *addr, size_t length, ucc_tl_mlx5_mcast_reg_t **reg); -ucc_status_t ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx, - ucc_tl_mlx5_mcast_reg_t *reg); +void ucc_tl_mlx5_mcast_mem_deregister(ucc_tl_mlx5_mcast_coll_context_t *ctx, + ucc_tl_mlx5_mcast_reg_t *reg); diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c index 6cac983bf0..1821b4375c 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* * See file LICENSE for terms. */ @@ -329,7 +329,7 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) case TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST: { - /* rank 0 has already called rdma_join_multicast() + /* rank 0 has already called rdma_join_multicast() * it is time to wait for the rdma event to confirm the join */ status = ucc_tl_mlx5_mcast_join_mcast_test(comm->ctx, &comm->event, 1); if (UCC_OK != status) { @@ -437,7 +437,7 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) status = ucc_tl_mlx5_mcast_coll_setup_comm_resources(comm); if (UCC_OK != status) { return status; - } + } tl_debug(comm->lib, "initialized tl mcast team: %p", tl_team); tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_READY; @@ -525,7 +525,7 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) case TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST: { - /* none-root rank has already called rdma_join_multicast() + /* none-root rank has already called rdma_join_multicast() * it is time to wait for the rdma event to confirm the join */ status = ucc_tl_mlx5_mcast_join_mcast_test(comm->ctx, &comm->event, 0); if (UCC_OK != status) { @@ -568,7 +568,7 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) status = ucc_tl_mlx5_mcast_coll_setup_comm_resources(comm); if (UCC_OK != status) { return status; - } + } tl_debug(comm->lib, "initialized tl mcast team: %p", tl_team); tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_READY; @@ -586,7 +586,7 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) { tl_error(comm->lib, "unknown state during mcast team: %p create", tl_team); return UCC_ERR_NO_RESOURCE; - } + } } } } diff --git a/src/components/tl/mlx5/tl_mlx5_coll.c b/src/components/tl/mlx5/tl_mlx5_coll.c index 90e224f9f6..861d4a4c67 100644 --- a/src/components/tl/mlx5/tl_mlx5_coll.c +++ b/src/components/tl/mlx5/tl_mlx5_coll.c @@ -45,7 +45,14 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args, ucc_status_t ucc_tl_mlx5_task_finalize(ucc_coll_task_t *coll_task) { - ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t); + ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t); + ucc_tl_mlx5_mcast_coll_req_t *req = task->bcast_mcast.req_handle; + + if (req != NULL) { + ucc_assert(coll_task->status != UCC_INPROGRESS); + ucc_free(req); + task->bcast_mcast.req_handle = NULL; + } tl_trace(UCC_TASK_LIB(task), "finalizing task %p", task); ucc_tl_mlx5_put_task(task); diff --git a/src/components/tl/mlx5/tl_mlx5_coll.h b/src/components/tl/mlx5/tl_mlx5_coll.h index 642dd71581..eb441bdcdf 100644 --- a/src/components/tl/mlx5/tl_mlx5_coll.h +++ b/src/components/tl/mlx5/tl_mlx5_coll.h @@ -79,6 +79,7 @@ ucc_tl_mlx5_get_task(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team) UCC_TL_MLX5_PROFILE_REQUEST_NEW(task, "tl_mlx5_task", 0); ucc_coll_task_init(&task->super, coll_args, team); + task->bcast_mcast.req_handle = NULL; return task; }
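
Reviewer note: the reliability tree built in ucc_tl_mlx5_mcast_prepare_reliable() above is a binomial tree over ranks rotated so that the bcast root becomes virtual rank 0. The standalone sketch below mirrors that parent/child derivation and is illustrative only: to_virtual()/to_original() are assumed stand-ins for the TO_VIRTUAL/TO_ORIGINAL macros used by the mcast code, and the printf calls replace the recv_nb postings to children.

#include <stdio.h>

/* Assumed rotate-by-root rank mappings (stand-ins for TO_VIRTUAL/TO_ORIGINAL). */
static int to_virtual(int rank, int size, int root)
{
    return (rank - root + size) % size;
}

static int to_original(int vrank, int size, int root)
{
    return (vrank + root) % size;
}

/* Mirrors the loop in ucc_tl_mlx5_mcast_prepare_reliable(): walk the bits of
 * the virtual rank; every peer across a bit below the lowest set bit is a
 * child, and the peer across the lowest set bit is the single parent. */
static void build_nack_tree(int rank, int size, int root)
{
    int vrank = to_virtual(rank, size, root);
    int mask  = 1;

    while (mask < size) {
        if (vrank & mask) {
            printf("rank %d: parent %d\n", rank, to_original(vrank ^ mask, size, root));
            break;
        }
        if ((vrank ^ mask) < size) {
            printf("rank %d: child  %d\n", rank, to_original(vrank ^ mask, size, root));
        }
        mask <<= 1;
    }
}

int main(void)
{
    int size = 6, root = 2, r;

    for (r = 0; r < size; r++) {
        build_nack_tree(r, size, root);
    }
    return 0;
}

Each non-root rank therefore has exactly one parent, which is the destination of the MCAST_P2P_NACK sent by ucc_tl_mlx5_mcast_reliable_send_NACK() when ucc_tl_mlx5_mcast_bcast_check_drop() detects a timeout.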