Skip to content

Commit

Permalink
TL/MLX5: adding mcast bcast algo
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB committed Mar 7, 2024
1 parent e13d962 commit 1276ca3
Show file tree
Hide file tree
Showing 16 changed files with 670 additions and 58 deletions.
1 change: 1 addition & 0 deletions src/components/tl/mlx5/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ mcast = \
mcast/p2p/ucc_tl_mlx5_mcast_p2p.h \
mcast/p2p/ucc_tl_mlx5_mcast_p2p.c \
mcast/tl_mlx5_mcast_progress.h \
mcast/tl_mlx5_mcast_progress.c \
mcast/tl_mlx5_mcast_helper.h \
mcast/tl_mlx5_mcast_helper.c \
mcast/tl_mlx5_mcast_team.c
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
18 changes: 9 additions & 9 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -64,14 +64,14 @@ typedef struct ucc_tl_mlx5_mcast_p2p_completion_obj {

/* Wait callback type: takes one opaque argument and returns an int. */
typedef int (*ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t)(void *wait_arg);

/*
 * Non-blocking p2p send callback type.
 *
 * Reports its result as a ucc_status_t. Only this definition is kept: a
 * second typedef of the same name with an int return type is a conflicting
 * redefinition in C.
 */
typedef ucc_status_t (*ucc_tl_mlx5_mcast_p2p_send_nb_fn_t)(void* src, size_t size,
                                                           ucc_rank_t rank, void *context,
                                                           ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj);

/*
 * Non-blocking p2p receive callback type; same parameter list and
 * ucc_status_t return type as the send callback.
 */
typedef ucc_status_t (*ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t)(void* src, size_t size,
                                                           ucc_rank_t rank, void *context,
                                                           ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj);

typedef struct ucc_tl_mlx5_mcast_p2p_interface {
ucc_tl_mlx5_mcast_p2p_send_nb_fn_t send_nb;
Expand Down Expand Up @@ -228,8 +228,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
void *p2p_ctx;
ucc_base_lib_t *lib;
struct sockaddr_in6 mcast_addr;
int parents[MAX_COMM_POW2];
int children[MAX_COMM_POW2];
ucc_rank_t parents[MAX_COMM_POW2];
ucc_rank_t children[MAX_COMM_POW2];
int nack_requests;
int nacks_counter;
int n_prep_reliable;
Expand Down
257 changes: 238 additions & 19 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
Original file line number Diff line number Diff line change
@@ -1,22 +1,248 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include "tl_mlx5_coll.h"
#include "tl_mlx5_mcast_helper.h"
#include "tl_mlx5_mcast_rcache.h"

ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* req /* NOLINT */)
/**
 * Runs a round of the reliability protocol when one is due and hands the
 * acknowledged window slots back to the buffer pool.
 *
 * A round is due in two cases:
 *   1. the window has no free slot left
 *      (comm->psn - comm->last_acked == comm->wsize);
 *   2. a zero-copy request has nothing left to send or receive: completion
 *      must not be signalled before the children are known to have received
 *      the data, because the user may modify the buffer afterwards.
 *
 * @param comm mcast communicator whose window (r_window) is recycled
 * @param req  request currently being progressed
 *
 * @return UCC_OK when no round was due or the window was recycled; otherwise
 *         the non-OK status returned by ucc_tl_mlx5_mcast_reliable() or
 *         ucc_tl_mlx5_mcast_prepare_reliable() (presumably UCC_INPROGRESS
 *         while ACKs are still outstanding - confirm against those helpers).
 */
static inline ucc_status_t ucc_tl_mlx5_mcast_r_window_recycle(ucc_tl_mlx5_mcast_coll_comm_t *comm,
                                                              ucc_tl_mlx5_mcast_coll_req_t  *req)
{
    ucc_status_t      status        = UCC_OK;
    int               wsize         = comm->wsize;
    int               num_free_win  = wsize - (comm->psn - comm->last_acked);
    int               req_completed = (req->to_send == 0 && req->to_recv == 0);
    struct pp_packet *pp            = NULL;

    ucc_assert(comm->recv_drop_packet_in_progress == false);
    ucc_assert(req->to_send >= 0);
    ucc_assert(num_free_win >= 0);

    if (!num_free_win || (req->proto == MCAST_PROTO_ZCOPY && req_completed)) {
        status = ucc_tl_mlx5_mcast_reliable(comm);
        if (UCC_OK != status) {
            return status;
        }

        comm->n_mcast_reliable++;

        /* Release every slot in [last_acked, psn): park the dummy packet in
           the window and return the real packet to the free pool.
           NOTE(review): the (wsize-1) mask assumes wsize is a power of two -
           confirm. */
        for (; comm->last_acked < comm->psn; comm->last_acked++) {
            pp = comm->r_window[comm->last_acked & (wsize-1)];
            ucc_assert(pp != &comm->dummy_packet);
            comm->r_window[comm->last_acked & (wsize-1)] = &comm->dummy_packet;

            pp->context = 0;
            ucc_list_add_tail(&comm->bpool, &pp->super);
        }

        /* More packets are still to be moved: arm the reliability protocol
           for the next window right away. */
        if (!req_completed) {
            status = ucc_tl_mlx5_mcast_prepare_reliable(comm, req, req->root);
            if (ucc_unlikely(UCC_OK != status)) {
                return status;
            }
        }
    }

    return UCC_OK;
}

/**
 * Performs one progress step of a multicast bcast request.
 *
 * Step order:
 *   - bail out with UCC_INPROGRESS while a dropped packet is being re-sent;
 *   - finish a pending reliability round, if any;
 *   - while packets remain: the root sends as many as fit in the window,
 *     the reliability protocol is prepared, receivers poll for packets, and
 *     the window is recycled when needed.
 *
 * @param req bcast request to progress (its communicator is req->comm)
 *
 * @return UCC_INPROGRESS while packets remain to be sent or received, or
 *         while a zero-copy request still has unacknowledged packets;
 *         UCC_ERR_NO_MESSAGE when ucc_tl_mlx5_mcast_recv() reports a negative
 *         count; any non-OK status from the reliability helpers; otherwise
 *         the last helper status.
 */
static inline ucc_status_t ucc_tl_mlx5_mcast_do_bcast(ucc_tl_mlx5_mcast_coll_req_t *req)
{
    ucc_tl_mlx5_mcast_coll_comm_t *comm   = req->comm;
    int                            wsize  = comm->wsize;
    int                            zcopy  = req->proto != MCAST_PROTO_EAGER;
    ucc_status_t                   status = UCC_OK;
    int                            free_slots;
    int                            already_sent;
    int                            batch;
    int                            wanted;
    int                            remaining;
    int                            pending;

    /* The parent has not yet re-sent the dropped packet: nothing to do. */
    if (ucc_unlikely(comm->recv_drop_packet_in_progress)) {
        return UCC_INPROGRESS;
    }

    /* Children have not all acknowledged the current window yet. */
    if (comm->reliable_in_progress) {
        status = ucc_tl_mlx5_mcast_r_window_recycle(comm, req);
        if (status != UCC_OK) {
            return status;
        }
    }

    if (req->to_send || req->to_recv) {
        free_slots = wsize - (comm->psn - comm->last_acked);

        /* Root side: push data while the window has room. */
        if (free_slots && req->am_root) {
            already_sent = req->num_packets - req->to_send;
            ucc_assert(req->to_send > 0);
            ucc_assert(req->first_send_psn + already_sent < comm->last_acked + wsize);
            if (req->to_send &&
                req->first_send_psn + already_sent < comm->last_acked + wsize) {
                /* Send everything that is left if it fits, otherwise only
                   up to the window limit. */
                batch = ucc_min(req->to_send,
                                comm->last_acked + wsize - (req->first_send_psn + already_sent));
                ucc_tl_mlx5_mcast_send(comm, req, batch, zcopy);

                free_slots = wsize - (comm->psn - comm->last_acked);
            }
        }

        status = ucc_tl_mlx5_mcast_prepare_reliable(comm, req, req->root);
        if (ucc_unlikely(status != UCC_OK)) {
            return status;
        }

        if (free_slots && req->to_recv) {
            /* Receive everything that is left, bounded by the free window. */
            pending   = 0;
            wanted    = ucc_min(free_slots, req->to_recv);
            remaining = ucc_tl_mlx5_mcast_recv(comm, req, wanted, &pending);

            if (remaining != wanted) {
                if (remaining < 0) {
                    /* a failure happened during cq polling */
                    return UCC_ERR_NO_MESSAGE;
                }
                /* Progress was made: reset the stall tracking. */
                comm->stalled = 0;
                comm->timer   = 0;
            } else if (comm->stalled++ >= DROP_THRESHOLD) {
                /* Nothing arrived for too many steps in a row: get ready for
                   a drop event. */
                tl_trace(comm->lib, "Did not receive the packet with psn in"
                         " current window range, so get ready for drop"
                         " event. pending_q_size %d current comm psn %d"
                         " last_acked psn %d stall threshold %d ",
                         pending, comm->psn, comm->last_acked,
                         DROP_THRESHOLD);

                status = ucc_tl_mlx5_mcast_bcast_check_drop(comm, req);
                if (status == UCC_INPROGRESS) {
                    return status;
                }
            }
        }

        /* Run a reliability round if the window state calls for one. */
        status = ucc_tl_mlx5_mcast_r_window_recycle(comm, req);
        if (status != UCC_OK) {
            return status;
        }
    }

    if (req->to_send || req->to_recv || (zcopy && comm->psn != comm->last_acked)) {
        return UCC_INPROGRESS;
    }

    return status;
}


/**
 * Progresses a bcast request once and, when it is no longer in progress,
 * releases the memory registration attached to it.
 *
 * @param req request to progress
 *
 * @return the status of ucc_tl_mlx5_mcast_do_bcast(). For any value other
 *         than UCC_INPROGRESS, req->rreg has been passed to
 *         ucc_tl_mlx5_mcast_mem_deregister() and reset to NULL.
 */
ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* req)
{
    ucc_status_t status;

    ucc_assert(req->comm->psn >= req->start_psn);

    status = ucc_tl_mlx5_mcast_do_bcast(req);
    if (status == UCC_INPROGRESS) {
        /* Still running: keep the registration for the next step. */
        return status;
    }

    /* Finished (successfully or not): drop the registration exactly once. */
    ucc_assert(req->comm->ctx != NULL);
    ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->rreg);
    req->rreg = NULL;

    return status;
}

static inline int ucc_tl_mlx5_mcast_prepare_bcast(void* buf, size_t size, ucc_rank_t root, void *mr,
ucc_tl_mlx5_mcast_coll_comm_t *comm,
ucc_tl_mlx5_mcast_coll_req_t *req)
{
ucc_status_t status;
ucc_tl_mlx5_mcast_reg_t *reg;

req->comm = comm;
req->ptr = buf;
req->length = size;
req->root = root;
req->am_root = (root == comm->rank);
req->mr = comm->pp_mr;
req->rreg = NULL;
req->proto = (req->length < comm->max_eager) ? MCAST_PROTO_EAGER : MCAST_PROTO_ZCOPY;

status = ucc_tl_mlx5_mcast_prepare_reliable(comm, req, req->root);
if (ucc_unlikely(UCC_OK != status)) {
return status;
}

if (req->am_root) {
if (req->proto != MCAST_PROTO_EAGER) {
status = ucc_tl_mlx5_mcast_mem_register(comm->ctx, req->ptr, req->length, &reg);
if (UCC_OK != status) {
return status;
}
req->rreg = reg;
req->mr = reg->mr;
}
}

req->offset = 0;
req->start_psn = comm->last_psn;
req->num_packets = ucc_max(ucc_div_round_up(req->length, comm->max_per_packet), 1);
req->last_pkt_len = req->length - (req->num_packets - 1)*comm->max_per_packet;

ucc_assert(req->last_pkt_len > 0 && req->last_pkt_len <= comm->max_per_packet);

comm->last_psn += req->num_packets;
req->first_send_psn = req->start_psn;
req->to_send = req->am_root ? req->num_packets : 0;
req->to_recv = req->am_root ? 0 : req->num_packets;

return UCC_OK;
}

ucc_status_t mcast_coll_do_bcast(void* buf, size_t size, ucc_rank_t root, void *mr, /* NOLINT */
ucc_tl_mlx5_mcast_coll_comm_t *comm, /* NOLINT */
ucc_tl_mlx5_mcast_coll_req_t **task_req_handle /* NOLINT */)
/**
 * Allocates and prepares a multicast bcast request.
 *
 * No data is moved here; the request is progressed afterwards through
 * ucc_tl_mlx5_mcast_test().
 *
 * @param buf             user buffer
 * @param size            buffer length in bytes
 * @param root            rank of the bcast root
 * @param mr              forwarded to ucc_tl_mlx5_mcast_prepare_bcast()
 * @param comm            mcast communicator
 * @param task_req_handle receives the new request on success; left untouched
 *                        on failure
 *
 * @return UCC_INPROGRESS when the request was created, UCC_ERR_NO_MEMORY if
 *         the allocation failed, or the error status reported by
 *         ucc_tl_mlx5_mcast_prepare_bcast() (the request is freed in that
 *         case).
 */
ucc_status_t ucc_tl_mlx5_mcast_coll_do_bcast(void* buf, size_t size, ucc_rank_t root, void *mr,
                                             ucc_tl_mlx5_mcast_coll_comm_t *comm,
                                             ucc_tl_mlx5_mcast_coll_req_t **task_req_handle)
{
    ucc_status_t                  status;
    ucc_tl_mlx5_mcast_coll_req_t *req;

    /* size is a size_t, so it is printed with %zu */
    tl_trace(comm->lib, "MCAST bcast start, buf %p, size %zu, root %d, comm %d, "
             "comm_size %d, am_i_root %d comm->psn = %d \n",
             buf, size, root, comm->comm_id, comm->commsize, comm->rank ==
             root, comm->psn );

    req = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_coll_req_t), "mcast_req");
    if (!req) {
        return UCC_ERR_NO_MEMORY;
    }

    status = ucc_tl_mlx5_mcast_prepare_bcast(buf, size, root, mr, comm, req);
    if (UCC_OK != status) {
        ucc_free(req);
        return status;
    }

    /* The collective has only been set up; completion is reported later by
       the progress path. */
    *task_req_handle = req;

    return UCC_INPROGRESS;
}

ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task)
Expand All @@ -35,8 +261,8 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task)

task->bcast_mcast.req_handle = NULL;

status = mcast_coll_do_bcast(buf, data_size, root, NULL, comm,
&task->bcast_mcast.req_handle);
status = ucc_tl_mlx5_mcast_coll_do_bcast(buf, data_size, root, NULL, comm,
&task->bcast_mcast.req_handle);
if (status < 0) {
tl_error(UCC_TASK_LIB(task), "mcast_coll_do_bcast failed:%d", status);
coll_task->status = status;
Expand All @@ -50,23 +276,16 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task)

/**
 * Progress callback of an mcast bcast task.
 *
 * Progresses the request attached to the task exactly once per call and
 * stores the outcome - in progress, completed, or an error - in
 * coll_task->status. Nothing is done when the task has no request.
 *
 * The request is not released here.
 * NOTE(review): confirm that task finalization frees bcast_mcast.req_handle.
 *
 * @param coll_task task being progressed
 */
void ucc_tl_mlx5_mcast_collective_progress(ucc_coll_task_t *coll_task)
{
    ucc_tl_mlx5_task_t           *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
    ucc_tl_mlx5_mcast_coll_req_t *req  = task->bcast_mcast.req_handle;

    if (req != NULL) {
        coll_task->status = ucc_tl_mlx5_mcast_test(req);
    }
}

ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task)
{

task->super.post = ucc_tl_mlx5_mcast_bcast_start;
task->super.progress = ucc_tl_mlx5_mcast_collective_progress;

Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
Loading

0 comments on commit 1276ca3

Please sign in to comment.