Skip to content

Commit

Permalink
TL/MLX5: fix memtype in bcast reliablity
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB committed Sep 27, 2024
1 parent 313f2da commit ff20b59
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 22 deletions.
21 changes: 12 additions & 9 deletions src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ void ucc_tl_mlx5_mcast_completion_cb(void* context, ucc_status_t status) //NOLIN

static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t
len, ucc_rank_t my_team_rank, ucc_rank_t dest,
ucc_memory_type_t mem_type,
ucc_team_h team, ucc_coll_callback_t *callback,
ucc_coll_req_h *p2p_req, int is_send,
ucc_base_lib_t *lib)
Expand All @@ -41,7 +42,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t
args.src.info.buffer = buf;
args.src.info.count = len;
args.src.info.datatype = UCC_DT_INT8;
args.src.info.mem_type = UCC_MEMORY_TYPE_HOST;
args.src.info.mem_type = mem_type;
args.root = is_send ? my_team_rank : dest;
args.cb.cb = callback->cb;
args.cb.data = callback->data;
Expand Down Expand Up @@ -69,25 +70,27 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t
}

static inline ucc_status_t do_send_nb(void *sbuf, size_t len, ucc_rank_t
my_team_rank, ucc_rank_t dest, ucc_team_h team,
my_team_rank, ucc_rank_t dest,
ucc_memory_type_t mem_type, ucc_team_h team,
ucc_coll_callback_t *callback,
ucc_coll_req_h *req, ucc_base_lib_t *lib)
{
return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(sbuf, len, my_team_rank, dest,
return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(sbuf, len, my_team_rank, dest, mem_type,
team, callback, req, 1, lib);
}

static inline ucc_status_t do_recv_nb(void *rbuf, size_t len, ucc_rank_t
my_team_rank, ucc_rank_t dest, ucc_team_h team,
my_team_rank, ucc_rank_t dest,
ucc_memory_type_t mem_type, ucc_team_h team,
ucc_coll_callback_t *callback,
ucc_coll_req_h *req, ucc_base_lib_t *lib)
{
return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(rbuf, len, my_team_rank, dest,
return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(rbuf, len, my_team_rank, dest, mem_type,
team, callback, req, 0, lib);
}

ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
rank, void *context,
rank, ucc_memory_type_t mem_type, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj)
{
Expand All @@ -103,7 +106,7 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
callback.data = obj;

tl_trace(oob_p2p_ctx->lib, "P2P: SEND to %d Msg Size %ld", rank, size);
status = do_send_nb(src, size, my_team_rank, rank, team, &callback, &req, oob_p2p_ctx->lib);
status = do_send_nb(src, size, my_team_rank, rank, mem_type, team, &callback, &req, oob_p2p_ctx->lib);

if (status < 0) {
tl_error(oob_p2p_ctx->lib, "nonblocking p2p send failed");
Expand All @@ -114,7 +117,7 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
}

ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void *dst, size_t size, ucc_rank_t
rank, void *context,
rank, ucc_memory_type_t mem_type, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj)
{
Expand All @@ -130,7 +133,7 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void *dst, size_t size, ucc_rank_t
callback.data = obj;

tl_trace(oob_p2p_ctx->lib, "P2P: RECV to %d Msg Size %ld", rank, size);
status = do_recv_nb(dst, size, my_team_rank, rank, team, &callback, &req, oob_p2p_ctx->lib);
status = do_recv_nb(dst, size, my_team_rank, rank, mem_type, team, &callback, &req, oob_p2p_ctx->lib);

if (status < 0) {
tl_error(oob_p2p_ctx->lib, "nonblocking p2p recv failed");
Expand Down
4 changes: 2 additions & 2 deletions src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
#include "components/tl/mlx5/mcast/tl_mlx5_mcast.h"

ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
rank, void *context,
rank, ucc_memory_type_t mem_type, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj);

ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void* dst, size_t size, ucc_rank_t
rank, void *context,
rank, ucc_memory_type_t mem_type, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj);

Expand Down
4 changes: 2 additions & 2 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,12 @@ typedef struct ucc_tl_mlx5_mcast_p2p_completion_obj {
typedef int (*ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t)(void *wait_arg);

typedef ucc_status_t (*ucc_tl_mlx5_mcast_p2p_send_nb_fn_t)(void* src, size_t size,
ucc_rank_t rank, void *context,
ucc_rank_t rank, ucc_memory_type_t mem_type, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj);


typedef ucc_status_t (*ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t)(void* src, size_t size,
ucc_rank_t rank, void *context,
ucc_rank_t rank, ucc_memory_type_t mem_type, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj);

typedef struct ucc_tl_mlx5_mcast_p2p_interface {
Expand Down
4 changes: 3 additions & 1 deletion src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_prepare_bcast(void* buf, size_t siz
req->am_root = (root == comm->rank);
req->mr = comm->pp_mr;
req->rreg = NULL;
req->proto = (req->length < comm->max_eager) ? MCAST_PROTO_EAGER : MCAST_PROTO_ZCOPY;
/* cost of copy is too high in cuda buffers */
req->proto = (req->length < comm->max_eager && !comm->cuda_mem_enabled) ?
MCAST_PROTO_EAGER : MCAST_PROTO_ZCOPY;

status = ucc_tl_mlx5_mcast_prepare_reliable(comm, req, req->root);
if (ucc_unlikely(UCC_OK != status)) {
Expand Down
45 changes: 37 additions & 8 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ static ucc_status_t ucc_tl_mlx5_mcast_reliability_send_completion(ucc_tl_mlx5_mc
comm->nack_requests--;
status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[pkt_id],
sizeof(struct packet), comm->p2p_pkt[pkt_id].from,
UCC_MEMORY_TYPE_HOST,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_recv_completion, pkt_id, NULL));
if (status < 0) {
Expand All @@ -48,6 +49,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_resend_packet_reliable(ucc_tl_mlx5_
uint32_t psn = comm->p2p_pkt[p2p_pkt_id].psn;
struct pp_packet *pp = comm->r_window[psn % comm->wsize];
ucc_status_t status;
ucc_memory_type_t mem_type;

ucc_assert(pp->psn == psn);
ucc_assert(comm->p2p_pkt[p2p_pkt_id].type == MCAST_P2P_NEED_NACK_SEND);
Expand All @@ -58,8 +60,14 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_resend_packet_reliable(ucc_tl_mlx5_
comm->comm_id, comm->rank,
comm->p2p_pkt[p2p_pkt_id].from, psn, pp->context, comm->nack_requests);

if (comm->cuda_mem_enabled) {
mem_type = UCC_MEMORY_TYPE_CUDA;
} else {
mem_type = UCC_MEMORY_TYPE_HOST;
}

status = comm->params.p2p_iface.send_nb((void*) (pp->context ? pp->context : pp->buf),
pp->length, comm->p2p_pkt[p2p_pkt_id].from,
pp->length, comm->p2p_pkt[p2p_pkt_id].from, mem_type,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_reliability_send_completion, NULL, p2p_pkt_id));
if (status < 0) {
Expand Down Expand Up @@ -138,11 +146,25 @@ static ucc_status_t ucc_tl_mlx5_mcast_recv_data_completion(ucc_tl_mlx5_mcast_p2p
struct pp_packet *pp = (struct pp_packet *)obj->data[1];
ucc_tl_mlx5_mcast_coll_req_t *req = (ucc_tl_mlx5_mcast_coll_req_t *)obj->data[2];
void *dest;
ucc_memory_type_t mem_type;

tl_trace(comm->lib, "[comm %d, rank %d] Recved data psn %d", comm->comm_id, comm->rank, pp->psn);

dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm);
memcpy(dest, (void*) pp->buf, pp->length);

if (comm->cuda_mem_enabled) {
mem_type = UCC_MEMORY_TYPE_CUDA;
} else {
mem_type = UCC_MEMORY_TYPE_HOST;
}

status = ucc_mc_memcpy(dest, (void*) pp->buf, pp->length,
mem_type, mem_type);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(comm->lib, "failed to copy buffer");
return status;
}

req->to_recv--;
comm->r_window[pp->psn % comm->wsize] = pp;

Expand All @@ -165,6 +187,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable_send_NACK(ucc_tl_mlx5_mcas
struct pp_packet *pp;
ucc_rank_t parent;
struct packet *p;
ucc_memory_type_t mem_type;

p = ucc_calloc(1, sizeof(struct packet));
p->type = MCAST_P2P_NACK;
Expand All @@ -176,7 +199,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable_send_NACK(ucc_tl_mlx5_mcas

comm->nacks_counter++;

status = comm->params.p2p_iface.send_nb(p, sizeof(struct packet), parent,
status = comm->params.p2p_iface.send_nb(p, sizeof(struct packet), parent, UCC_MEMORY_TYPE_HOST,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_reliability_send_completion, p, UINT_MAX));
if (status < 0) {
Expand All @@ -193,8 +216,14 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable_send_NACK(ucc_tl_mlx5_mcas

comm->recv_drop_packet_in_progress = true;

if (comm->cuda_mem_enabled) {
mem_type = UCC_MEMORY_TYPE_CUDA;
} else {
mem_type = UCC_MEMORY_TYPE_HOST;
}

status = comm->params.p2p_iface.recv_nb((void*) pp->buf,
pp->length, parent,
pp->length, parent, mem_type,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_recv_data_completion, pp, req));
if (status < 0) {
Expand Down Expand Up @@ -225,7 +254,7 @@ ucc_status_t ucc_tl_mlx5_mcast_reliable_send(ucc_tl_mlx5_mcast_coll_comm_t *comm
comm->rank, parent, comm->parent_n, comm->psn);

status = comm->params.p2p_iface.send_nb(&comm->p2p_spkt[i],
sizeof(struct packet), parent,
sizeof(struct packet), parent, UCC_MEMORY_TYPE_HOST,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_send_completion, i, NULL));
if (status < 0) {
Expand Down Expand Up @@ -325,7 +354,7 @@ ucc_status_t ucc_tl_mlx5_mcast_prepare_reliable(ucc_tl_mlx5_mcast_coll_comm_t *c
comm->rank, child, comm->child_n, comm->psn);

status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[comm->child_n - 1],
sizeof(struct packet), child,
sizeof(struct packet), child, UCC_MEMORY_TYPE_HOST,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_recv_completion, comm->child_n - 1, req));
if (status < 0) {
Expand Down Expand Up @@ -369,8 +398,8 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com
ucc_tl_mlx5_mcast_coll_req_t *req,
struct pp_packet* pp)
{
ucc_status_t status = UCC_OK;
void *dest;
ucc_status_t status = UCC_OK;
void *dest;
ucc_memory_type_t mem_type;
ucc_assert(pp->psn >= req->start_psn &&
pp->psn < req->start_psn + req->num_packets);
Expand Down

0 comments on commit ff20b59

Please sign in to comment.