TL/MLX5: Addressing Sam's comments from July 1st
MamziB committed Jul 1, 2024
1 parent 044e785 commit 023835b
Showing 12 changed files with 326 additions and 341 deletions.
297 changes: 148 additions & 149 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h

Large diffs are not rendered by default.

48 changes: 33 additions & 15 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c
@@ -11,6 +11,24 @@
#include "tl_mlx5_mcast_allgather.h"
#include <inttypes.h>

+/* 32 here is the bit count of ib mcast packet's immediate data */
+#define TL_MLX5_MCAST_IB_IMMEDIATE_PACKET_BIT_COUNT 32
+
+#define ONE_SIDED_MAX_PACKET_COUNT(_max_count)                        \
+do {                                                                  \
+    int pow2;                                                         \
+    int tmp;                                                          \
+    pow2       = log(ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE) / log(2);   \
+    tmp        = TL_MLX5_MCAST_IB_IMMEDIATE_PACKET_BIT_COUNT - pow2;  \
+    pow2       = log(ONE_SIDED_MAX_ALLGATHER_COUNTER) / log(2);       \
+    tmp        = tmp - pow2;                                          \
+    _max_count = pow(2, tmp);                                         \
+} while(0);
+
+#define MCAST_ALLGATHER_IN_PROGRESS(_req, _comm)                      \
+    (_req->to_send || _req->to_recv || _comm->pending_send ||         \
+     _comm->one_sided.rdma_read_in_progress || (NULL != _req->allgather_rkeys_req)) \
+
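A note on the bit budget above: the 32 bits of IB multicast immediate data are split three ways — log2(ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE) bits for the source rank, log2(ONE_SIDED_MAX_ALLGATHER_COUNTER) bits for the collective counter, and the remainder for the per-collective packet count. Because both limits are powers of two, the same result can be computed with integer shifts instead of log()/pow(); a minimal standalone sketch, using illustrative constants (8192 and 1024 are assumptions, not the values defined in tl_mlx5_mcast.h):

/* Example (not part of the diff): integer version of the immediate-data
 * bit budget.  Constants are hypothetical stand-ins. */
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

#define IMM_BITS              32    /* bits of IB immediate data */
#define MAX_TEAM_SIZE         8192  /* assumed power of two      */
#define MAX_ALLGATHER_COUNTER 1024  /* assumed power of two      */

static int log2_int(uint32_t v)
{
    int n = 0;
    assert(v && !(v & (v - 1)));    /* must be a power of two    */
    while (v >>= 1) {
        n++;
    }
    return n;
}

int main(void)
{
    int rank_bits    = log2_int(MAX_TEAM_SIZE);          /* 13 */
    int counter_bits = log2_int(MAX_ALLGATHER_COUNTER);  /* 10 */
    /* Whatever is left encodes the packet count: 32 - 13 - 10 = 9 bits. */
    uint32_t max_packet_count = 1u << (IMM_BITS - rank_bits - counter_bits);

    printf("max packet count per collective: %u\n", (unsigned)max_packet_count); /* 512 */
    return 0;
}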
static inline ucc_status_t ucc_tl_mlx5_mcast_check_collective(ucc_tl_mlx5_mcast_coll_comm_t *comm,
ucc_tl_mlx5_mcast_coll_req_t *req)
{
@@ -80,7 +98,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reset_reliablity(ucc_tl_mlx5_mcast_
ucc_tl_mlx5_mcast_reg_t *reg = NULL;
ucc_status_t status;

-    ucc_assert(req->ag_counter == comm->ag_under_progress_counter);
+    ucc_assert(req->ag_counter == comm->allgather_comm.under_progress_counter);

if (comm->one_sided.reliability_enabled && !comm->one_sided.reliability_ready) {
/* initialize the structures needed by reliablity protocol */
@@ -129,7 +147,7 @@ static inline void ucc_tl_mlx5_mcast_init_async_reliability_slots(ucc_tl_mlx5_mc
ucc_tl_mlx5_mcast_coll_comm_t *comm = req->comm;
void *dest;

-    ucc_assert(req->ag_counter == comm->ag_under_progress_counter);
+    ucc_assert(req->ag_counter == comm->allgather_comm.under_progress_counter);

if (ONE_SIDED_ASYNCHRONOUS_PROTO == req->one_sided_reliability_scheme &&
ONE_SIDED_INVALID == comm->one_sided.slots_state) {
@@ -162,10 +180,10 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_allgather(ucc_tl_mlx5_mcast_coll
}

if (req->to_send || req->to_recv) {
-        ucc_assert(comm->max_push_send >= comm->pending_send);
+        ucc_assert(comm->allgather_comm.max_push_send >= comm->pending_send);
if (req->to_send &&
-            (comm->max_push_send - comm->pending_send) > 0) {
-            ucc_tl_mlx5_mcast_send_collective(comm, req, ucc_min(comm->max_push_send -
+            (comm->allgather_comm.max_push_send - comm->pending_send) > 0) {
+            ucc_tl_mlx5_mcast_send_collective(comm, req, ucc_min(comm->allgather_comm.max_push_send -
comm->pending_send, req->to_send),
zcopy, UCC_COLL_TYPE_ALLGATHER, -1, SIZE_MAX);
}
@@ -223,7 +241,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_allgather(ucc_tl_mlx5_mcast_coll
}
}

-ucc_status_t ucc_tl_mlx5_mcast_test_allgather(ucc_tl_mlx5_mcast_coll_req_t* req)
+static inline ucc_status_t ucc_tl_mlx5_mcast_test_allgather(ucc_tl_mlx5_mcast_coll_req_t* req)
{
ucc_status_t status;

@@ -281,12 +299,12 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_prepare_allgather(void* sbuf, void
req->num_packets = 1;
}

-    ONE_SIDED_MAX_PACKET_COUNT(comm->ag_max_num_packets);
+    ONE_SIDED_MAX_PACKET_COUNT(comm->allgather_comm.max_num_packets);

-    if (comm->ag_max_num_packets < req->num_packets) {
+    if (comm->allgather_comm.max_num_packets < req->num_packets) {
tl_warn(comm->lib,
"msg size is %ld but max supported msg size of mcast allgather is %d",
-                req->length, comm->ag_max_num_packets * comm->max_per_packet);
+                req->length, comm->allgather_comm.max_num_packets * comm->max_per_packet);
return UCC_ERR_NOT_SUPPORTED;
}

@@ -312,11 +330,11 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_prepare_allgather(void* sbuf, void
req->one_sided_reliability_scheme = ONE_SIDED_NO_RELIABILITY;
}

-    req->ag_counter = comm->ag_counter;
+    req->ag_counter = comm->allgather_comm.coll_counter;
req->to_send = req->num_packets;
req->to_recv = comm->commsize * req->num_packets;

-    comm->ag_counter++;
+    comm->allgather_comm.coll_counter++;
return UCC_OK;
}

@@ -329,7 +347,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_coll_do_allgather(void* sbuf, void

tl_trace(comm->lib, "MCAST allgather start, sbuf %p, rbuf %p, size %d, comm %d, "
"comm_size %d, counter %d",
-             sbuf, rbuf, size, comm->comm_id, comm->commsize, comm->ag_counter);
+             sbuf, rbuf, size, comm->comm_id, comm->commsize, comm->allgather_comm.coll_counter);

req = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_coll_req_t), "mcast_req");
if (!req) {
@@ -387,9 +405,9 @@ void ucc_tl_mlx5_mcast_allgather_progress(ucc_coll_task_t *coll_task)

if (task->coll_mcast.req_handle != NULL) {
req = task->coll_mcast.req_handle;
-        if (req->ag_counter != req->comm->ag_under_progress_counter) {
+        if (req->ag_counter != req->comm->allgather_comm.under_progress_counter) {
/* it is not this task's turn for progress */
-            ucc_assert(req->comm->ag_under_progress_counter < req->ag_counter);
+            ucc_assert(req->comm->allgather_comm.under_progress_counter < req->ag_counter);
return;
}

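The counter checks above are what keep pipelined allgathers ordered: prepare_allgather stamps each request with allgather_comm.coll_counter, the progress callback only advances the request whose stamp equals allgather_comm.under_progress_counter, and completion bumps that counter. A toy model of the ticket scheme (the struct and function names here are stand-ins, not real UCC types):

/* Example (not part of the diff): turn-taking via matching counters. */
#include <assert.h>
#include <stdbool.h>

struct toy_comm { int coll_counter, under_progress_counter; };
struct toy_req  { int ag_counter; };

static void toy_start(struct toy_comm *c, struct toy_req *r)
{
    r->ag_counter = c->coll_counter++;        /* take a ticket at init */
}

static bool toy_may_progress(const struct toy_comm *c, const struct toy_req *r)
{
    return r->ag_counter == c->under_progress_counter;
}

int main(void)
{
    struct toy_comm comm = {0, 0};
    struct toy_req  r1, r2;

    toy_start(&comm, &r1);                    /* ticket 0 */
    toy_start(&comm, &r2);                    /* ticket 1 */

    assert(toy_may_progress(&comm, &r1));     /* r1's turn            */
    assert(!toy_may_progress(&comm, &r2));    /* r2 waits its turn    */

    comm.under_progress_counter++;            /* r1 completed         */
    assert(toy_may_progress(&comm, &r2));     /* now r2 may progress  */
    return 0;
}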
@@ -398,7 +416,7 @@ void ucc_tl_mlx5_mcast_allgather_progress(ucc_coll_task_t *coll_task)
return;
} else if (UCC_OK == status) {
coll_task->status = UCC_OK;
-            req->comm->ag_under_progress_counter++;
+            req->comm->allgather_comm.under_progress_counter++;
ucc_free(req);
task->coll_mcast.req_handle = NULL;
} else {
6 changes: 0 additions & 6 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.h
@@ -10,12 +10,6 @@
#include "tl_mlx5_mcast.h"
#include "tl_mlx5_coll.h"

-#define MCAST_ALLGATHER_IN_PROGRESS(_req, _comm)                      \
-    (_req->to_send || _req->to_recv || _comm->pending_send ||         \
-     _comm->one_sided.rdma_read_in_progress || (NULL != _req->allgather_rkeys_req)) \

ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task);

-ucc_status_t ucc_tl_mlx5_mcast_test_allgather(ucc_tl_mlx5_mcast_coll_req_t* _req);

#endif
44 changes: 22 additions & 22 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
@@ -12,12 +12,12 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_r_window_recycle(ucc_tl_mlx5_mcast_
ucc_tl_mlx5_mcast_coll_req_t *req)
{
ucc_status_t status = UCC_OK;
-    int wsize = comm->wsize;
-    int num_free_win = wsize - (comm->psn - comm->last_acked);
+    int wsize = comm->bcast_comm.wsize;
+    int num_free_win = wsize - (comm->psn - comm->bcast_comm.last_acked);
int req_completed = (req->to_send == 0 && req->to_recv == 0);
struct pp_packet *pp = NULL;

-    ucc_assert(comm->recv_drop_packet_in_progress == false);
+    ucc_assert(comm->bcast_comm.recv_drop_packet_in_progress == false);
ucc_assert(req->to_send >= 0);

/* When do we need to perform reliability protocol:
@@ -33,12 +33,12 @@
return status;
}

-    comm->n_mcast_reliable++;
+    comm->bcast_comm.n_mcast_reliable++;

-    for (;comm->last_acked < comm->psn; comm->last_acked++) {
-        pp = comm->r_window[comm->last_acked & (wsize-1)];
+    for (;comm->bcast_comm.last_acked < comm->psn; comm->bcast_comm.last_acked++) {
+        pp = comm->r_window[comm->bcast_comm.last_acked & (wsize-1)];
ucc_assert(pp != &comm->dummy_packet);
-        comm->r_window[comm->last_acked & (wsize-1)] = &comm->dummy_packet;
+        comm->r_window[comm->bcast_comm.last_acked & (wsize-1)] = &comm->dummy_packet;

pp->context = 0;
ucc_list_add_tail(&comm->bpool, &pp->super);
@@ -60,7 +60,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_bcast(ucc_tl_mlx5_mcast_coll_req
ucc_status_t status = UCC_OK;
ucc_tl_mlx5_mcast_coll_comm_t *comm = req->comm;
int zcopy = req->proto != MCAST_PROTO_EAGER;
-    int wsize = comm->wsize;
+    int wsize = comm->bcast_comm.wsize;
int num_free_win;
int num_sent;
int to_send;
@@ -74,29 +74,29 @@
return status;
}

-    if (ucc_unlikely(comm->recv_drop_packet_in_progress)) {
+    if (ucc_unlikely(comm->bcast_comm.recv_drop_packet_in_progress)) {
/* wait till parent resend the dropped packet */
return UCC_INPROGRESS;
}


if (req->to_send || req->to_recv) {
-        num_free_win = wsize - (comm->psn - comm->last_acked);
+        num_free_win = wsize - (comm->psn - comm->bcast_comm.last_acked);

/* Send data if i'm root and there is a space in the window */
if (num_free_win && req->am_root) {
num_sent = req->num_packets - req->to_send;
ucc_assert(req->to_send > 0);
-            ucc_assert(req->first_send_psn + num_sent < comm->last_acked + wsize);
-            if (req->first_send_psn + num_sent < comm->last_acked + wsize &&
+            ucc_assert(req->first_send_psn + num_sent < comm->bcast_comm.last_acked + wsize);
+            if (req->first_send_psn + num_sent < comm->bcast_comm.last_acked + wsize &&
req->to_send) {
/* How many to send: either all that are left (if they fit into window) or
up to the window limit */
to_send = ucc_min(req->to_send,
-                                  comm->last_acked + wsize - (req->first_send_psn + num_sent));
+                                  comm->bcast_comm.last_acked + wsize - (req->first_send_psn + num_sent));
ucc_tl_mlx5_mcast_send(comm, req, to_send, zcopy);

-                num_free_win = wsize - (comm->psn - comm->last_acked);
+                num_free_win = wsize - (comm->psn - comm->bcast_comm.last_acked);
}
}

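For intuition on the window arithmetic in this hunk: comm->psn - bcast_comm.last_acked is the number of packets still in flight, wsize minus that is the free window space, and the root may never push a psn at or beyond bcast_comm.last_acked + wsize. A self-contained toy computation with made-up numbers (a sketch of the bookkeeping only, not the real send path):

/* Example (not part of the diff): sliding-window bookkeeping. */
#include <stdio.h>

static int int_min(int a, int b) { return a < b ? a : b; }

int main(void)
{
    int wsize          = 64;   /* bcast_comm.wsize, a power of two     */
    int psn            = 100;  /* next psn the comm will send          */
    int last_acked     = 70;   /* everything below this is acked       */
    int first_send_psn = 90;   /* first psn of the current request     */
    int num_sent       = 10;   /* packets of this request already sent */
    int to_send_left   = 25;   /* packets of this request still queued */

    int num_free_win = wsize - (psn - last_acked);                   /* 64 - 30 = 34 */

    /* Send what is left, clamped so we never pass last_acked + wsize. */
    int limit   = last_acked + wsize - (first_send_psn + num_sent);  /* 34 */
    int to_send = int_min(to_send_left, limit);                      /* 25 */

    printf("free window %d, sending %d packets\n", num_free_win, to_send);
    return 0;
}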
@@ -119,8 +119,8 @@
tl_trace(comm->lib, "Did not receive the packet with psn in"
" current window range, so get ready for drop"
" event. pending_q_size %d current comm psn %d"
" last_acked psn %d stall threshold %d ",
pending_q_size, comm->psn, comm->last_acked,
" bcast_comm.last_acked psn %d stall threshold %d ",
pending_q_size, comm->psn, comm->bcast_comm.last_acked,
DROP_THRESHOLD);

status = ucc_tl_mlx5_mcast_bcast_check_drop(comm, req);
@@ -144,7 +144,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_bcast(ucc_tl_mlx5_mcast_coll_req
return status;
}

-    if (req->to_send || req->to_recv || (zcopy && comm->psn != comm->last_acked)) {
+    if (req->to_send || req->to_recv || (zcopy && comm->psn != comm->bcast_comm.last_acked)) {
return UCC_INPROGRESS;
} else {
return status;
@@ -201,16 +201,16 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_prepare_bcast(void* buf, size_t siz
}

req->offset = 0;
-    req->start_psn = comm->last_psn;
+    req->start_psn = comm->bcast_comm.last_psn;
req->num_packets = ucc_max(ucc_div_round_up(req->length, comm->max_per_packet), 1);
req->last_pkt_len = req->length - (req->num_packets - 1)*comm->max_per_packet;

ucc_assert(req->last_pkt_len > 0 && req->last_pkt_len <= comm->max_per_packet);

-    comm->last_psn += req->num_packets;
-    req->first_send_psn = req->start_psn;
-    req->to_send = req->am_root ? req->num_packets : 0;
-    req->to_recv = req->am_root ? 0 : req->num_packets;
+    comm->bcast_comm.last_psn += req->num_packets;
+    req->first_send_psn        = req->start_psn;
+    req->to_send               = req->am_root ? req->num_packets : 0;
+    req->to_recv               = req->am_root ? 0 : req->num_packets;

return UCC_OK;
}
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c
@@ -657,7 +657,7 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
tl_debug(comm->lib, "comm_id %d, comm_size %d, comm->psn %d, rank %d, "
"nacks counter %d, n_mcast_rel %d",
comm->comm_id, comm->commsize, comm->psn, comm->rank,
-             comm->nacks_counter, comm->n_mcast_reliable);
+             comm->bcast_comm.nacks_counter, comm->bcast_comm.n_mcast_reliable);
}

if (comm->p2p_ctx != NULL) {
36 changes: 19 additions & 17 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h
@@ -86,7 +86,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t
swr[0].imm_data = htonl(pp->psn);
swr[0].send_flags = (length <= comm->max_inline) ? IBV_SEND_INLINE : 0;

-        comm->r_window[pp->psn & (comm->wsize-1)] = pp;
+        comm->r_window[pp->psn & (comm->bcast_comm.wsize-1)] = pp;
comm->psn++;
req->to_send--;
offset += length;
@@ -102,7 +102,7 @@
pp->psn, pp->length, zcopy, swr[0].send_flags & IBV_SEND_SIGNALED);

if (0 != (rc = ibv_post_send(comm->mcast.qp, &swr[0], &bad_wr))) {
tl_error(comm->lib, "Post send failed: ret %d, start_psn %d, to_send %d, "
tl_error(comm->lib, "post send failed: ret %d, start_psn %d, to_send %d, "
"to_recv %d, length %d, psn %d, inline %d",
rc, req->start_psn, req->to_send, req->to_recv,
length, pp->psn, length <= comm->max_inline);
@@ -127,7 +127,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_process_pp(ucc_tl_mlx5_mcast_coll_c
{
ucc_status_t status = UCC_OK;

-    if (PSN_RECEIVED(pp->psn, comm) || pp->psn < comm->last_acked) {
+    if (PSN_RECEIVED(pp->psn, comm) || pp->psn < comm->bcast_comm.last_acked) {
/* This psn was already received */
ucc_assert(pp->context == 0);
if (in_pending_queue) {
@@ -336,7 +336,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send_collective(ucc_tl_mlx5_mcast_c
mcast_group_index);

if (0 != (rc = ibv_post_send(comm->mcast.qp_list[mcast_group_index], &swr[0], &bad_wr))) {
tl_error(comm->lib, "Post send failed: ret %d, start_psn %d, to_send %d, "
tl_error(comm->lib, "post send failed: ret %d, start_psn %d, to_send %d, "
"to_recv %d, length %d, psn %d, inline %d",
rc, req->start_psn, req->to_send, req->to_recv,
length, pp->psn, length <= comm->max_inline);
@@ -499,16 +499,17 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable(ucc_tl_mlx5_mcast_coll_com
{
ucc_status_t status = UCC_OK;

-    if (comm->racks_n != comm->child_n || comm->sacks_n != comm->parent_n ||
-        comm->nack_requests) {
+    if (comm->bcast_comm.racks_n != comm->bcast_comm.child_n ||
+        comm->bcast_comm.sacks_n != comm->bcast_comm.parent_n ||
+        comm->bcast_comm.nack_requests) {
if (comm->pending_send) {
status = ucc_tl_mlx5_mcast_poll_send(comm);
if (UCC_OK != status) {
return status;
}
}

-        if (comm->parent_n) {
+        if (comm->bcast_comm.parent_n) {
status = ucc_tl_mlx5_mcast_poll_recv(comm);
if (UCC_OK != status) {
return status;
@@ -521,26 +522,27 @@
}
}

-    if (comm->parent_n && !comm->reliable_in_progress) {
+    if (comm->bcast_comm.parent_n && !comm->bcast_comm.reliable_in_progress) {
status = ucc_tl_mlx5_mcast_reliable_send(comm);
if (UCC_OK != status) {
return status;
}
}

-    if (!comm->reliable_in_progress) {
-        comm->reliable_in_progress = 1;
+    if (!comm->bcast_comm.reliable_in_progress) {
+        comm->bcast_comm.reliable_in_progress = 1;
}

-    if (comm->racks_n == comm->child_n && comm->sacks_n == comm->parent_n &&
-        0 == comm->nack_requests) {
+    if (comm->bcast_comm.racks_n == comm->bcast_comm.child_n &&
+        comm->bcast_comm.sacks_n == comm->bcast_comm.parent_n && 0 ==
+        comm->bcast_comm.nack_requests) {
// Reset for next round.
-        memset(comm->parents, 0, sizeof(comm->parents));
-        memset(comm->children, 0, sizeof(comm->children));
+        memset(comm->bcast_comm.parents, 0, sizeof(comm->bcast_comm.parents));
+        memset(comm->bcast_comm.children, 0, sizeof(comm->bcast_comm.children));

-        comm->racks_n = comm->child_n = 0;
-        comm->sacks_n = comm->parent_n = 0;
-        comm->reliable_in_progress = 0;
+        comm->bcast_comm.racks_n = comm->bcast_comm.child_n = 0;
+        comm->bcast_comm.sacks_n = comm->bcast_comm.parent_n = 0;
+        comm->bcast_comm.reliable_in_progress = 0;

return UCC_OK;
}
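These hunks move the bcast reliability bookkeeping under comm->bcast_comm. The round is considered complete only when the ack counters match on both sides of the tree and no NACK retransmissions are outstanding, at which point the parent/child state is reset for the next round. A hedged restatement with a stand-in struct (field meanings are inferred from the names and may differ from the real ucc_tl_mlx5_mcast_coll_comm_t):

/* Example (not part of the diff): the completion predicate and reset. */
#include <assert.h>
#include <stdbool.h>

struct toy_bcast_comm {
    int racks_n, child_n;     /* acks seen from children vs. expected */
    int sacks_n, parent_n;    /* acks seen from parents  vs. expected */
    int nack_requests;        /* outstanding NACK retransmits         */
    int reliable_in_progress;
};

static bool reliability_round_done(const struct toy_bcast_comm *b)
{
    return b->racks_n == b->child_n &&
           b->sacks_n == b->parent_n &&
           b->nack_requests == 0;
}

static void reliability_round_reset(struct toy_bcast_comm *b)
{
    /* Mirrors the counter reset at the end of ucc_tl_mlx5_mcast_reliable(). */
    b->racks_n = b->child_n = 0;
    b->sacks_n = b->parent_n = 0;
    b->reliable_in_progress  = 0;
}

int main(void)
{
    struct toy_bcast_comm b = { 2, 2, 1, 1, 0, 1 };
    assert(reliability_round_done(&b));
    reliability_round_reset(&b);
    assert(b.racks_n == 0 && b.reliable_in_progress == 0);
    return 0;
}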
@@ -242,11 +242,11 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet_collective(ucc_tl_mlx5_mcast_coll_

if (comm->one_sided.recvd_pkts_tracker[source_rank] > req->num_packets) {
tl_error(comm->lib, "reliablity failed: comm->one_sided.recvd_pkts_tracker[%d] %d"
" req->num_packets %d offset %d PACKET_TO_DROP %d"
" comm->ag_under_progress_counter %d req->ag_counter"
" req->num_packets %d offset %d"
" comm->allgather_comm.under_progress_counter %d req->ag_counter"
" %d \n", source_rank, comm->one_sided.recvd_pkts_tracker[source_rank],
-                 req->num_packets, offset, PACKET_TO_DROP,
-                 comm->ag_under_progress_counter, req->ag_counter);
+                 req->num_packets, offset,
+                 comm->allgather_comm.under_progress_counter, req->ag_counter);
return UCC_ERR_NO_MESSAGE;
}
}
