diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h index cab8046281..66da5ff474 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h @@ -16,6 +16,7 @@ #include "components/tl/ucc_tl.h" #include "components/tl/ucc_tl_log.h" #include "utils/ucc_rcache.h" +#include "core/ucc_service_coll.h" #define POLL_PACKED 16 #define REL_DONE ((void*)-1) @@ -119,6 +120,7 @@ typedef struct ucc_tl_mlx5_mcast_rcache_region { } ucc_tl_mlx5_mcast_rcache_region_t; typedef struct ucc_tl_mlx5_mcast_ctx_params { + int mcast_enabled; char *ib_dev_name; int print_nack_stats; int timeout; @@ -142,11 +144,19 @@ typedef struct ucc_tl_mlx5_mcast_coll_context { ucc_base_lib_t *lib; } ucc_tl_mlx5_mcast_coll_context_t; +typedef struct ucc_tl_mlx5_mcast_join_info_t { + ucc_status_t status; + uint16_t dlid; + union ibv_gid dgid; +} ucc_tl_mlx5_mcast_join_info_t; + typedef struct ucc_tl_mlx5_mcast_context { ucc_thread_mode_t tm; ucc_tl_mlx5_mcast_coll_context_t mcast_context; ucc_tl_mlx5_mcast_context_config_t cfg; ucc_mpool_t req_mp; + int mcast_enabled; + int mcast_ready; ucc_tl_mlx5_mcast_oob_ctx_t oob_ctx; } ucc_tl_mlx5_mcast_context_t; @@ -225,6 +235,11 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm { int n_prep_reliable; int n_mcast_reliable; int wsize; + ucc_tl_mlx5_mcast_join_info_t *group_setup_info; + ucc_service_coll_req_t *group_setup_info_req; + ucc_status_t (*bcast_post) (void*, void*, size_t, ucc_rank_t, ucc_service_coll_req_t**); + ucc_status_t (*bcast_test) (ucc_service_coll_req_t*); + struct rdma_cm_event *event; struct pp_packet *r_window[1]; // do not add any new variable after here } ucc_tl_mlx5_mcast_coll_comm_t; @@ -352,6 +367,8 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context, const ucc_base_team_params_t *params, ucc_tl_mlx5_mcast_coll_comm_init_spec_t *mcast_conf); +ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team); + 
ucc_status_t ucc_tl_mlx5_mcast_coll_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, ucc_coll_task_t **task_h); @@ -359,4 +376,6 @@ ucc_status_t ucc_tl_mlx5_mcast_coll_init(ucc_base_coll_args_t *coll_args, ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *mcast_ctx, ucc_tl_mlx5_mcast_ctx_params_t *mcast_ctx_conf); + +ucc_status_t ucc_tl_mlx5_mcast_clean_ctx(ucc_tl_mlx5_mcast_coll_context_t *ctx); #endif diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c index ad32c459b0..5361f1deb5 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c @@ -51,12 +51,20 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont int ib_valid; const char *dst; + mlx5_ctx = ucc_container_of(context, ucc_tl_mlx5_context_t, mcast); + lib = mlx5_ctx->super.super.lib; + + context->mcast_enabled = mcast_ctx_conf->mcast_enabled; + + if (!mcast_ctx_conf->mcast_enabled) { + tl_debug(lib, "Mcast is disabled by the user"); + return UCC_ERR_NO_RESOURCE; + } + ctx = &(context->mcast_context); memset(ctx, 0, sizeof(ucc_tl_mlx5_mcast_coll_context_t)); memcpy(&ctx->params, mcast_ctx_conf, sizeof(ucc_tl_mlx5_mcast_ctx_params_t)); - mlx5_ctx = ucc_container_of(context, ucc_tl_mlx5_context_t, mcast); - lib = mlx5_ctx->super.super.lib; ctx->lib = lib; /* TODO unify all the contexts under TL mlx5 */ @@ -239,13 +247,55 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont error: if (ctx->pd) { ibv_dealloc_pd(ctx->pd); + ctx->pd = NULL; } if (ctx->id) { rdma_destroy_id(ctx->id); + ctx->id = NULL; } if (ctx->channel) { rdma_destroy_event_channel(ctx->channel); + ctx->channel = NULL; } return status; } + +ucc_status_t ucc_tl_mlx5_mcast_clean_ctx(ucc_tl_mlx5_mcast_coll_context_t *ctx) +{ + /* a NULL ctx is a no-op; test it before tl_debug() reads ctx->lib */ + if (ctx == NULL) return UCC_OK; + tl_debug(ctx->lib, "cleaning mcast ctx: %p", ctx); + + if 
(ctx->rcache) { + ucc_rcache_destroy(ctx->rcache); + ctx->rcache = NULL; + } + + if (ctx->pd) { + if (ibv_dealloc_pd(ctx->pd)) { + tl_error(ctx->lib, "ibv_dealloc_pd failed errno %d", errno); + return UCC_ERR_NO_RESOURCE; + } + ctx->pd = NULL; + } + + if (ctx->id && rdma_destroy_id(ctx->id)) { + tl_error(ctx->lib, "rdma_destroy_id failed errno %d", errno); + return UCC_ERR_NO_RESOURCE; + } + + ctx->id = NULL; + + if (ctx->channel) { + rdma_destroy_event_channel(ctx->channel); + ctx->channel = NULL; + } + + if (ctx->devname && !strcmp(ctx->params.ib_dev_name, "")) { + ucc_free(ctx->devname); + ctx->devname = NULL; + } + + return UCC_OK; +} diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c index 8c52a63c73..de3af55a60 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c @@ -529,33 +529,3 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm) return UCC_OK; } -ucc_status_t ucc_tl_mlx5_clean_mcast_ctx(ucc_tl_mlx5_mcast_coll_context_t *ctx) -{ - tl_debug(ctx->lib, "cleaning mcast ctx: %p", ctx); - - if (ctx->rcache) { - ucc_rcache_destroy(ctx->rcache); - } - - if (ctx->pd) { - if (ibv_dealloc_pd(ctx->pd)) { - tl_error(ctx->lib, "ibv_dealloc_pd failed errno %d", errno); - return UCC_ERR_NO_RESOURCE; - } - } - - if (rdma_destroy_id(ctx->id)) { - tl_error(ctx->lib, "rdma_destroy_id failed errno %d", errno); - return UCC_ERR_NO_RESOURCE; - } - - rdma_destroy_event_channel(ctx->channel); - - if (!strcmp(ctx->params.ib_dev_name, "")) { - ucc_free(ctx->devname); - } - - ucc_free(ctx); - - return UCC_OK; -} diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h index 05037e495f..bd3e7521fb 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h @@ -365,4 +365,12 @@ ucc_status_t 
ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm); +ucc_status_t ucc_tl_mlx5_mcast_join_mcast_post(ucc_tl_mlx5_mcast_coll_context_t *ctx, + struct sockaddr_in6 *net_addr, + int is_root); + +ucc_status_t ucc_tl_mlx5_mcast_join_mcast_test(ucc_tl_mlx5_mcast_coll_context_t *ctx, + struct rdma_cm_event **event, + int is_root); + #endif /* TL_MLX5_MCAST_HELPER_H_ */ diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c index f56bc3c1a1..6cac983bf0 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c @@ -9,14 +9,184 @@ #include "tl_mlx5_mcast_coll.h" #include "coll_score/ucc_coll_score.h" #include "tl_mlx5_mcast_helper.h" +#include "p2p/ucc_tl_mlx5_mcast_p2p.h" +#include "mcast/tl_mlx5_mcast_helper.h" + +static ucc_status_t ucc_tl_mlx5_mcast_service_bcast_post(void *arg, void *buf, size_t size, ucc_rank_t root, + ucc_service_coll_req_t **bcast_req) +{ + ucc_tl_mlx5_mcast_oob_p2p_context_t *ctx = (ucc_tl_mlx5_mcast_oob_p2p_context_t *)arg; + ucc_status_t status = UCC_OK; + ucc_team_t *team = ctx->base_team; + ucc_subset_t subset = ctx->subset; + ucc_service_coll_req_t *req = NULL; + + status = ucc_service_bcast(team, buf, size, root, subset, &req); + if (ucc_unlikely(UCC_OK != status)) { + tl_error(ctx->base_ctx->lib, "tl service mcast bcast failed"); + return status; + } + + *bcast_req = req; + + return status; +} -ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, /* NOLINT */ - ucc_tl_mlx5_mcast_team_t **mcast_team, /* NOLINT */ - ucc_tl_mlx5_mcast_context_t *ctx, /* NOLINT */ - const ucc_base_team_params_t *params, /* NOLINT */ - ucc_tl_mlx5_mcast_coll_comm_init_spec_t *mcast_conf /* NOLINT */) +static ucc_status_t ucc_tl_mlx5_mcast_service_bcast_test(ucc_service_coll_req_t *req) { + ucc_status_t status = UCC_OK; + + status = 
ucc_service_coll_test(req); + + if (UCC_INPROGRESS != status) { + ucc_service_coll_finalize(req); + } + + return status; +} + +ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, + ucc_tl_mlx5_mcast_team_t **mcast_team, + ucc_tl_mlx5_mcast_context_t *ctx, + const ucc_base_team_params_t *team_params, + ucc_tl_mlx5_mcast_coll_comm_init_spec_t *mcast_conf) +{ + ucc_status_t status; + ucc_subset_t set; + ucc_tl_mlx5_mcast_coll_comm_init_spec_t comm_spec = *mcast_conf; + ucc_tl_mlx5_mcast_coll_context_t *mcast_context = &(ctx->mcast_context); + ucc_tl_mlx5_mcast_coll_comm_init_spec_t *conf_params = &comm_spec; + ucc_context_t *context = base_context->ucc_context; + ucc_tl_mlx5_mcast_team_t *new_mcast_team; + ucc_tl_mlx5_mcast_oob_p2p_context_t *oob_p2p_ctx; + ucc_tl_mlx5_mcast_coll_comm_t *comm; + int i; + + if (!ctx->mcast_enabled || NULL == mcast_context) { + tl_debug(base_context->lib, + "mcast context not available, base_context = %p", + base_context ); + return UCC_ERR_NO_RESOURCE; + } + + new_mcast_team = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_team_t), "new_mcast_team"); + + if (!new_mcast_team) { + return UCC_ERR_NO_MEMORY; + } + + new_mcast_team->mcast_context = ctx; + + /* init p2p interface */ + conf_params->p2p_iface.send_nb = ucc_tl_mlx5_mcast_p2p_send_nb; + conf_params->p2p_iface.recv_nb = ucc_tl_mlx5_mcast_p2p_recv_nb; + + oob_p2p_ctx = ucc_malloc(sizeof(ucc_tl_mlx5_mcast_oob_p2p_context_t), + "oob_p2p_ctx"); + if (!oob_p2p_ctx) { + ucc_free(new_mcast_team); + return UCC_ERR_NO_MEMORY; + } + + oob_p2p_ctx->base_ctx = context; + oob_p2p_ctx->base_team = team_params->team; + oob_p2p_ctx->my_team_rank = team_params->rank; + set.myrank = team_params->rank; + set.map = team_params->map; + oob_p2p_ctx->subset = set; + conf_params->oob = oob_p2p_ctx; + conf_params->sx_sge = 1; + conf_params->rx_sge = 2; + conf_params->scq_moderation = 64; + + comm = (ucc_tl_mlx5_mcast_coll_comm_t*) + ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_coll_comm_t) + + 
sizeof(struct pp_packet*)*(conf_params->wsize-1), + "ucc_tl_mlx5_mcast_coll_comm_t"); + if (!comm) { + ucc_free(oob_p2p_ctx); + ucc_free(new_mcast_team); + return UCC_ERR_NO_MEMORY; + } + + ucc_list_head_init(&comm->bpool); + ucc_list_head_init(&comm->pending_q); + + comm->bcast_post = ucc_tl_mlx5_mcast_service_bcast_post; + comm->bcast_test = ucc_tl_mlx5_mcast_service_bcast_test; + + memcpy(&comm->params, conf_params, sizeof(*conf_params)); + + comm->wsize = conf_params->wsize; + comm->max_eager = conf_params->max_eager; + comm->comm_id = team_params->id; + comm->ctx = mcast_context; + comm->grh_buf = (char *)ucc_malloc(GRH_LENGTH * sizeof(char), "grh_buf"); + if (!comm->grh_buf) { + status = UCC_ERR_NO_MEMORY; + goto cleanup; + } + + memset(comm->grh_buf, 0, GRH_LENGTH); + + comm->grh_mr = ibv_reg_mr(mcast_context->pd, comm->grh_buf, GRH_LENGTH, + IBV_ACCESS_REMOTE_WRITE | + IBV_ACCESS_LOCAL_WRITE); + if (!comm->grh_mr) { + tl_error(mcast_context->lib, "could not register memory for GRH, errno %d", errno); + status = UCC_ERR_NO_RESOURCE; + goto cleanup; + } + + comm->rcq = ibv_create_cq(mcast_context->ctx, comm->params.rx_depth, NULL, NULL, 0); + if (!comm->rcq) { + ibv_dereg_mr(comm->grh_mr); + tl_error(mcast_context->lib, "could not create recv cq, rx_depth %d, errno %d", + comm->params.rx_depth, errno); + status = UCC_ERR_NO_RESOURCE; + goto cleanup; + } + + comm->scq = ibv_create_cq(mcast_context->ctx, comm->params.sx_depth, NULL, NULL, 0); + if (!comm->scq) { + ibv_dereg_mr(comm->grh_mr); + ibv_destroy_cq(comm->rcq); + tl_error(mcast_context->lib, "could not create send cq, sx_depth %d, errno %d", + comm->params.sx_depth, errno); + status = UCC_ERR_NO_RESOURCE; + goto cleanup; + } + + comm->rank = team_params->rank; + comm->commsize = team_params->size; + comm->max_per_packet = mcast_context->mtu - GRH_LENGTH; + comm->last_acked = comm->last_psn = 0; + comm->racks_n = comm->sacks_n = 0; + comm->child_n = comm->parent_n = 0; + comm->p2p_ctx = 
conf_params->oob; + + memcpy(&comm->p2p, &conf_params->p2p_iface, + sizeof(ucc_tl_mlx5_mcast_p2p_interface_t)); + + comm->dummy_packet.psn = UINT32_MAX; + + for (i=0; i< comm->wsize; i++) { + comm->r_window[i] = &comm->dummy_packet; + } + + comm->lib = base_context->lib; + new_mcast_team->mcast_comm = comm; + *mcast_team = new_mcast_team; + tl_debug(base_context->lib, "posted tl mcast team : %p", new_mcast_team); + return UCC_OK; + +cleanup: + ucc_free(comm->grh_buf); /* NULL or no longer registered on every path that jumps here */ + ucc_free(comm); + ucc_free(new_mcast_team); + ucc_free(oob_p2p_ctx); + return status; } ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_comm_t *comm) @@ -128,3 +298,295 @@ ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_ ucc_tl_mlx5_clean_mcast_comm(comm); return status; } + +ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) +{ + ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); + ucc_tl_mlx5_mcast_coll_comm_t *comm = tl_team->mcast->mcast_comm; + ucc_status_t status = UCC_OK; + struct sockaddr_in6 net_addr = {0,}; + ucc_tl_mlx5_mcast_join_info_t *data = NULL; + + if (comm->rank == 0) { + switch(tl_team->mcast_state) { + case TL_MLX5_TEAM_STATE_MCAST_INIT: + { + /* now it is time for rank 0 to call rdma_join_multicast() */ + net_addr.sin6_family = AF_INET6; + net_addr.sin6_flowinfo = comm->comm_id; + status = ucc_tl_mlx5_mcast_join_mcast_post(comm->ctx, &net_addr, 1); + if (status < 0) { + tl_error(comm->lib, "rank 0 is unable to join mcast group error %d", status); + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_FAILED; + return UCC_INPROGRESS; + } + + comm->mcast_addr = net_addr; + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST; + + return UCC_INPROGRESS; + } + + case TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST: + { + /* rank 0 has already called rdma_join_multicast() + * it is time to wait for the rdma event to confirm the join */ + status = ucc_tl_mlx5_mcast_join_mcast_test(comm->ctx, &comm->event, 
1); + if (UCC_OK != status) { + if (status < 0) { + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_FAILED; + } + if (comm->event) { + if (rdma_ack_cm_event(comm->event) < 0) { + tl_error(comm->lib, "rdma_ack_cm_event failed"); + return UCC_ERR_NO_RESOURCE; + } + comm->event = NULL; + } + return UCC_INPROGRESS; + } + + ucc_assert(comm->event != NULL); + + /* at this point, rank 0 has joined mcast group */ + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_READY; + + return UCC_INPROGRESS; + } + + case TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_READY: + case TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_FAILED: + { + + data = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_join_info_t), + "ucc_tl_mlx5_mcast_join_info_t"); + if (!data) { + tl_error(comm->lib, "unable to allocate memory for group setup info"); + return UCC_ERR_NO_MEMORY; + } + + comm->group_setup_info = data; + + if (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_READY) { + /* rank 0 bcast the lid/gid to other processes */ + data->status = UCC_OK; + data->dgid = comm->event->param.ud.ah_attr.grh.dgid; + data->dlid = comm->event->param.ud.ah_attr.dlid; + comm->mcast_lid = data->dlid; + comm->mgid = data->dgid; + } else { + /* rank 0 bcast the failed status to other processes so others do not hang */ + data->status = UCC_ERR_NO_RESOURCE; + } + + status = comm->bcast_post(comm->p2p_ctx, data, sizeof(ucc_tl_mlx5_mcast_join_info_t), + 0, &comm->group_setup_info_req); + if (UCC_OK != status) { + tl_error(comm->lib, "unable to post bcast for group setup info"); + ucc_free(comm->group_setup_info); + if (comm->event) { + if (rdma_ack_cm_event(comm->event) < 0) { + tl_error(comm->lib, "rdma_ack_cm_event failed"); + return UCC_ERR_NO_RESOURCE; + } + comm->event = NULL; + } + return status; + } + + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_BCAST_POST; + + return UCC_INPROGRESS; + } + + case TL_MLX5_TEAM_STATE_MCAST_GRP_BCAST_POST: + { + /* rank 0 polls bcast request and wait for its completion */ + status 
= comm->bcast_test(comm->group_setup_info_req); + if (UCC_OK != status) { + /* bcast is not completed yet, or it failed (status < 0) */ + if (status < 0) { + if (comm->event && rdma_ack_cm_event(comm->event) < 0) { + tl_error(comm->lib, "rdma_ack_cm_event failed"); + } + ucc_free(comm->group_setup_info); + } + return status; + } + + if (comm->group_setup_info->status != UCC_OK) { + /* rank 0 was not able to join a mcast group so all + * the ranks should return; comm->event may be NULL here */ + if (comm->event && rdma_ack_cm_event(comm->event) < 0) { + tl_error(comm->lib, "rdma_ack_cm_event failed"); + } + ucc_free(comm->group_setup_info); + return UCC_ERR_NO_RESOURCE; + } + + ucc_free(comm->group_setup_info); + if (comm->event) { + if (rdma_ack_cm_event(comm->event) < 0) { + tl_error(comm->lib, "rdma_ack_cm_event failed"); + return UCC_ERR_NO_RESOURCE; + } + comm->event = NULL; + } + + /* setup of the rest of the mcast resources */ + status = ucc_tl_mlx5_mcast_coll_setup_comm_resources(comm); + if (UCC_OK != status) { + return status; + } + + tl_debug(comm->lib, "initialized tl mcast team: %p", tl_team); + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_READY; + + return UCC_INPROGRESS; + } + + case TL_MLX5_TEAM_STATE_MCAST_READY: + case TL_MLX5_TEAM_STATE_MCAST_NOT_AVAILABLE: + { + return UCC_OK; + } + + default: + { + tl_error(comm->lib, "unknown state during mcast team: %p create", tl_team); + return UCC_ERR_NO_RESOURCE; + } + } + } else { + /* team create states for the non-root ranks */ + switch(tl_team->mcast_state) { + case TL_MLX5_TEAM_STATE_MCAST_INIT: + { + /* non-root ranks post the bcast to receive the lid/gid + * of the mcast group from rank 0 */ + data = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_join_info_t), + "ucc_tl_mlx5_mcast_join_info_t"); + if (!data) { + tl_error(comm->lib, "unable to allocate memory for group setup info"); + return UCC_ERR_NO_MEMORY; + } + + status = comm->bcast_post(comm->p2p_ctx, data, sizeof(ucc_tl_mlx5_mcast_join_info_t), + 0, &comm->group_setup_info_req); + if (UCC_OK != status) { + tl_error(comm->lib, "unable to post bcast for 
group setup info"); + ucc_free(data); + return status; + } + + comm->group_setup_info = data; + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_BCAST_POST; + + return UCC_INPROGRESS; + } + + case TL_MLX5_TEAM_STATE_MCAST_GRP_BCAST_POST: + { + /* none rank 0 processes poll bcast request and wait for its completion */ + status = comm->bcast_test(comm->group_setup_info_req); + if (UCC_OK != status) { + /* bcast is not completed yet */ + if (status < 0) { + ucc_free(comm->group_setup_info); + } + return status; + } + + data = comm->group_setup_info; + status = data->status; + if (UCC_OK != status) { + /* rank 0 was not able to join a mcast group so all + * the ranks should return */ + ucc_free(data); + return status; + } + + /* now it is time for none rank 0 to call rdma_join_multicast() */ + memcpy(&net_addr.sin6_addr, &(data->dgid), sizeof(struct in6_addr)); + net_addr.sin6_family = AF_INET6; + + status = ucc_tl_mlx5_mcast_join_mcast_post(comm->ctx, &net_addr, 0); + if (status < 0) { + tl_error(comm->lib, "none-root rank is unable to join mcast group error %d", status); + ucc_free(data); + return status; + } + + comm->mcast_addr = net_addr; + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST; + + return UCC_INPROGRESS; + } + + case TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST: + { + /* none-root rank has already called rdma_join_multicast() + * it is time to wait for the rdma event to confirm the join */ + status = ucc_tl_mlx5_mcast_join_mcast_test(comm->ctx, &comm->event, 0); + if (UCC_OK != status) { + if (comm->event) { + if (rdma_ack_cm_event(comm->event) < 0) { + tl_error(comm->lib, "rdma_ack_cm_event failed"); + return UCC_ERR_NO_RESOURCE; + } + comm->event = NULL; + } + if (status < 0) { + ucc_free(comm->group_setup_info); + } + return status; + } + + ucc_assert(comm->event != NULL); + + comm->mcast_lid = comm->group_setup_info->dlid; + comm->mgid = comm->group_setup_info->dgid; + + ucc_free(comm->group_setup_info); + if (comm->event) { + if 
(rdma_ack_cm_event(comm->event) < 0) { + tl_error(comm->lib, "rdma_ack_cm_event failed"); + return UCC_ERR_NO_RESOURCE; + } + comm->event = NULL; + } + + /* at this point, none-root rank has joined mcast group */ + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_READY; + + return UCC_INPROGRESS; + } + + case TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_READY: + { + /* setup of the rest of the mcast resources */ + status = ucc_tl_mlx5_mcast_coll_setup_comm_resources(comm); + if (UCC_OK != status) { + return status; + } + + tl_debug(comm->lib, "initialized tl mcast team: %p", tl_team); + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_READY; + + return UCC_INPROGRESS; + } + + case TL_MLX5_TEAM_STATE_MCAST_READY: + case TL_MLX5_TEAM_STATE_MCAST_NOT_AVAILABLE: + { + return UCC_OK; + } + + default: + { + tl_error(comm->lib, "unknown state during mcast team: %p create", tl_team); + return UCC_ERR_NO_RESOURCE; + } + } + } +} diff --git a/src/components/tl/mlx5/tl_mlx5.c b/src/components/tl/mlx5/tl_mlx5.c index 0210f2302c..3e1f24f4ca 100644 --- a/src/components/tl/mlx5/tl_mlx5.c +++ b/src/components/tl/mlx5/tl_mlx5.c @@ -102,6 +102,10 @@ static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = { ucc_offsetof(ucc_tl_mlx5_context_config_t, mcast_ctx_conf.timeout), UCC_CONFIG_TYPE_INT}, + {"MCAST_ENABLE", "0", "Enable Mcast", + ucc_offsetof(ucc_tl_mlx5_context_config_t, mcast_ctx_conf.mcast_enabled), + UCC_CONFIG_TYPE_INT}, + {"MCAST_NET_DEVICE", "", "Specifies which network device to use for Mcast", ucc_offsetof(ucc_tl_mlx5_context_config_t, mcast_ctx_conf.ib_dev_name), UCC_CONFIG_TYPE_STRING}, diff --git a/src/components/tl/mlx5/tl_mlx5.h b/src/components/tl/mlx5/tl_mlx5.h index 155e6144af..8dbe4ff408 100644 --- a/src/components/tl/mlx5/tl_mlx5.h +++ b/src/components/tl/mlx5/tl_mlx5.h @@ -106,27 +106,41 @@ typedef enum TL_MLX5_TEAM_STATE_INIT, TL_MLX5_TEAM_STATE_POSTED, TL_MLX5_TEAM_STATE_ALLTOALL_INIT, - TL_MLX5_TEAM_STATE_ALLTOALL_POSTED + 
TL_MLX5_TEAM_STATE_ALLTOALL_POSTED, + TL_MLX5_TEAM_STATE_ALLTOALL_READY, + TL_MLX5_TEAM_STATE_ALLTOALL_NOT_AVAILABLE } ucc_tl_mlx5_team_state_t; +typedef enum +{ + TL_MLX5_TEAM_STATE_MCAST_INIT, + TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST, + TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_READY, + TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_FAILED, + TL_MLX5_TEAM_STATE_MCAST_GRP_BCAST_POST, + TL_MLX5_TEAM_STATE_MCAST_READY, + TL_MLX5_TEAM_STATE_MCAST_NOT_AVAILABLE +} ucc_tl_mlx5_team_mcast_state_t; + typedef struct ucc_tl_mlx5_team_status { ucc_status_t local; ucc_status_t global; } ucc_tl_mlx5_team_status_t; typedef struct ucc_tl_mlx5_team { - ucc_tl_team_t super; - ucc_service_coll_req_t *scoll_req; - ucc_tl_mlx5_team_state_t state; - void *dm_offset; - ucc_mpool_t dm_pool; - struct ibv_dm *dm_ptr; - struct ibv_mr *dm_mr; - ucc_tl_mlx5_team_status_t a2a_status; - ucc_tl_mlx5_alltoall_t *a2a; - ucc_topo_t *topo; - ucc_ep_map_t ctx_map; - ucc_tl_mlx5_mcast_team_t *mcast; + ucc_tl_team_t super; + ucc_service_coll_req_t *scoll_req; + ucc_tl_mlx5_team_state_t a2a_state; + ucc_tl_mlx5_team_mcast_state_t mcast_state; + void *dm_offset; + ucc_mpool_t dm_pool; + struct ibv_dm *dm_ptr; + struct ibv_mr *dm_mr; + ucc_tl_mlx5_team_status_t a2a_status; + ucc_tl_mlx5_alltoall_t *a2a; + ucc_topo_t *topo; + ucc_ep_map_t ctx_map; + ucc_tl_mlx5_mcast_team_t *mcast; } ucc_tl_mlx5_team_t; UCC_CLASS_DECLARE(ucc_tl_mlx5_team_t, ucc_base_context_t *, const ucc_base_team_params_t *); diff --git a/src/components/tl/mlx5/tl_mlx5_context.c b/src/components/tl/mlx5/tl_mlx5_context.c index 77d02acc04..1bf21d50c0 100644 --- a/src/components/tl/mlx5/tl_mlx5_context.c +++ b/src/components/tl/mlx5/tl_mlx5_context.c @@ -49,6 +49,13 @@ UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_context_t, goto err_rcache; } + status = ucc_tl_mlx5_mcast_context_init(&(self->mcast), &(self->cfg.mcast_ctx_conf)); + if (UCC_OK != status) { + self->mcast.mcast_ready = 0; + tl_debug(self->super.super.lib, "failed to initialize mcast context"); + } else { + 
self->mcast.mcast_ready = 1; + } return UCC_OK; err_rcache: @@ -72,6 +79,10 @@ UCC_CLASS_CLEANUP_FUNC(ucc_tl_mlx5_context_t) } ucc_mpool_cleanup(&self->req_mp, 1); + + if (self->mcast.mcast_ready) { + ucc_tl_mlx5_mcast_clean_ctx(&self->mcast.mcast_context); + } } UCC_CLASS_DEFINE(ucc_tl_mlx5_context_t, ucc_tl_context_t); diff --git a/src/components/tl/mlx5/tl_mlx5_team.c b/src/components/tl/mlx5/tl_mlx5_team.c index 939f29a3d4..16c85b54b9 100644 --- a/src/components/tl/mlx5/tl_mlx5_team.c +++ b/src/components/tl/mlx5/tl_mlx5_team.c @@ -11,6 +11,8 @@ #include "alltoall/alltoall.h" #include "core/ucc_team.h" #include +#include "mcast/tl_mlx5_mcast.h" +#include "mcast/tl_mlx5_mcast_helper.h" static ucc_status_t ucc_tl_mlx5_topo_init(ucc_tl_mlx5_team_t *team) { @@ -65,12 +67,22 @@ UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_team_t, ucc_base_context_t *tl_context, } self->a2a = NULL; - status = ucc_tl_mlx5_team_init_alltoall(self); + status = ucc_tl_mlx5_team_init_alltoall(self); if (UCC_OK != status) { return status; } - self->state = TL_MLX5_TEAM_STATE_INIT; + self->mcast = NULL; + status = ucc_tl_mlx5_mcast_team_init(tl_context, &(self->mcast), &(ctx->mcast), params, + &(UCC_TL_MLX5_TEAM_LIB(self)->cfg.mcast_conf)); + if (UCC_OK != status) { + self->mcast_state = TL_MLX5_TEAM_STATE_MCAST_NOT_AVAILABLE; + } else { + self->mcast_state = TL_MLX5_TEAM_STATE_MCAST_INIT; + } + + self->a2a_state = TL_MLX5_TEAM_STATE_INIT; + tl_debug(tl_context->lib, "posted tl team: %p", self); return UCC_OK; } @@ -82,6 +94,9 @@ UCC_CLASS_CLEANUP_FUNC(ucc_tl_mlx5_team_t) ucc_tl_mlx5_dm_cleanup(self); ucc_tl_mlx5_alltoall_cleanup(self); ucc_tl_mlx5_topo_cleanup(self); + if (self->mcast_state != TL_MLX5_TEAM_STATE_MCAST_NOT_AVAILABLE) { + ucc_tl_mlx5_clean_mcast_comm(self->mcast->mcast_comm); + } } UCC_CLASS_DEFINE_DELETE_FUNC(ucc_tl_mlx5_team_t, ucc_base_team_t); @@ -93,15 +108,16 @@ ucc_status_t ucc_tl_mlx5_team_destroy(ucc_base_team_t *tl_team) return UCC_OK; } -ucc_status_t 
ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) +static inline ucc_status_t ucc_tl_mlx5_a2a_team_test(ucc_base_team_t *team) { ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); ucc_team_t *core_team = UCC_TL_CORE_TEAM(tl_team); ucc_subset_t subset = {.map = UCC_TL_TEAM_MAP(tl_team), .myrank = UCC_TL_TEAM_RANK(tl_team)}; + ucc_status_t status = UCC_OK; - switch (tl_team->state) { + switch (tl_team->a2a_state) { case TL_MLX5_TEAM_STATE_INIT: status = ucc_service_allreduce( core_team, &tl_team->a2a_status.local, &tl_team->a2a_status.global, @@ -111,7 +127,7 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) "failed to collect global status"); return status; } - tl_team->state = TL_MLX5_TEAM_STATE_POSTED; + tl_team->a2a_state = TL_MLX5_TEAM_STATE_POSTED; case TL_MLX5_TEAM_STATE_POSTED: status = ucc_service_coll_test(tl_team->scoll_req); if (status < 0) { @@ -124,11 +140,11 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) return status; } ucc_service_coll_finalize(tl_team->scoll_req); - tl_team->state = TL_MLX5_TEAM_STATE_ALLTOALL_INIT; + tl_team->a2a_state = TL_MLX5_TEAM_STATE_ALLTOALL_INIT; case TL_MLX5_TEAM_STATE_ALLTOALL_INIT: tl_team->a2a_status.local = ucc_tl_mlx5_team_test_alltoall_start(tl_team); - tl_team->state = TL_MLX5_TEAM_STATE_ALLTOALL_POSTED; + tl_team->a2a_state = TL_MLX5_TEAM_STATE_ALLTOALL_POSTED; case TL_MLX5_TEAM_STATE_ALLTOALL_POSTED: // coverity[deref_arg:FALSE] tl_team->a2a_status.local = @@ -140,9 +156,52 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) tl_debug(UCC_TL_TEAM_LIB(tl_team), "failed to init a2a: %s", ucc_status_string(tl_team->a2a_status.local)); } + tl_team->a2a_state = TL_MLX5_TEAM_STATE_ALLTOALL_READY; + tl_debug(team->context->lib, "initialized tl a2a team: %p", tl_team); + case TL_MLX5_TEAM_STATE_ALLTOALL_READY: + case TL_MLX5_TEAM_STATE_ALLTOALL_NOT_AVAILABLE: + return UCC_OK; + default: + tl_error(team->context->lib, "unknown state during a2a 
team: %p create", tl_team); + return UCC_ERR_NO_RESOURCE; + } +} + +ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) +{ + ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); + ucc_status_t a2a_status = UCC_OK; + ucc_status_t mcast_status = UCC_OK; + + a2a_status = ucc_tl_mlx5_a2a_team_test(team); + if (a2a_status < 0) { + tl_error(team->context->lib, "ALLTOALL tl team: %p creation failed %d", + team, a2a_status); + tl_team->a2a_state = TL_MLX5_TEAM_STATE_ALLTOALL_NOT_AVAILABLE; + } + + if (tl_team->mcast_state != TL_MLX5_TEAM_STATE_MCAST_NOT_AVAILABLE) { + mcast_status = ucc_tl_mlx5_mcast_team_test(team); + if (mcast_status < 0) { + tl_error(team->context->lib, "MCAST tl team: %p creation failed %d", + team, mcast_status); + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_NOT_AVAILABLE; + } + } + + if (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_NOT_AVAILABLE && + tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_NOT_AVAILABLE) { + tl_error(team->context->lib, "unable to initialize tl team: %p", team); + return UCC_ERR_NO_RESOURCE; + } + + if (UCC_OK != a2a_status || UCC_OK != mcast_status) { + return UCC_INPROGRESS; } - tl_debug(team->context->lib, "initialized tl team: %p", tl_team); + tl_debug(team->context->lib, "initialized tl team: %p: MCAST component is %s ALLTOALL component is %s", + team, (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY)?"ENABLED":"DISABLED", + (tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_READY)?"ENABLED":"DISABLED"); return UCC_OK; }