diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h index 1d08f99edf..28626b7d13 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h @@ -131,6 +131,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_context { struct ibv_pd *pd; char *devname; int max_qp_wr; + int user_provided_ib; int ib_port; int pkey_index; int mtu; diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c index 192000ee86..a6ac59361f 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c @@ -32,12 +32,11 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont int is_ipv4 = 0; struct sockaddr_in *in_src_addr = NULL; struct rdma_cm_event *revent = NULL; - char *ib = NULL; - char *ib_name = NULL; - char *port = NULL; int active_mtu = 4096; int max_mtu = 4096; ucc_tl_mlx5_mcast_coll_context_t *ctx = NULL; + char *ib_devname = NULL; + int devname_len = 0, ib_port = 1; struct ibv_port_attr port_attr; struct ibv_device_attr device_attr; struct sockaddr_storage ip_oib_addr; @@ -47,9 +46,9 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont ucc_tl_mlx5_context_t *mlx5_ctx; ucc_base_lib_t *lib; int i; - int user_provided_ib; int ib_valid; const char *dst; + char tmp[128], *pos, *end_pos; mlx5_ctx = ucc_container_of(context, ucc_tl_mlx5_context_t, mcast); lib = mlx5_ctx->super.super.lib; @@ -86,12 +85,34 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont memset(ctx->devname, 0, strlen(devname)+3); memcpy(ctx->devname, devname, strlen(devname)); strncat(ctx->devname, ":1", 3); - user_provided_ib = 0; + ctx->user_provided_ib = 0; + ctx->ib_port = 1; } else { ib_valid = 0; /* user has provided the devname now make sure it is valid */ + /* check if port number is also included and extract devname from user str */ + ib_devname = mcast_ctx_conf->ib_dev_name; + pos = strstr(ib_devname, ":"); + if (!pos) { + devname_len = sizeof(tmp) - 1; + } else { + devname_len = (int)(pos - ib_devname); + pos++; + errno = 0; + ib_port = (int)strtol(pos, &end_pos, 0); + if (errno != 0 || pos == end_pos || strcmp(end_pos,"\0") || ib_port < 0 + || ib_port > UINT8_MAX ) { + tl_warn(lib, "wrong device's port number"); + return UCC_ERR_INVALID_PARAM; + } + } + ctx->ib_port = ib_port; + strncpy(tmp, ib_devname, devname_len); + tmp[devname_len] = '\0'; + ib_devname = tmp; + for (i = 0; device_list[i]; ++i) { - if (!strcmp(ibv_get_device_name(device_list[i]), mcast_ctx_conf->ib_dev_name)) { + if (!strcmp(ibv_get_device_name(device_list[i]), ib_devname)) { ib_valid = 1; break; } @@ -102,8 +123,8 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont ibv_free_device_list(device_list); goto error; } - ctx->devname = mcast_ctx_conf->ib_dev_name; - user_provided_ib = 1; + ctx->devname = mcast_ctx_conf->ib_dev_name; + ctx->user_provided_ib = 1; } ibv_free_device_list(device_list); @@ -111,7 +132,7 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont status = ucc_tl_mlx5_probe_ip_over_ib(ctx->devname, &ip_oib_addr); if (UCC_OK != status) { tl_debug(lib, "failed to get ipoib interface for devname %s", ctx->devname); - if (!user_provided_ib) { + if (!ctx->user_provided_ib) { ucc_free(ctx->devname); } goto error; @@ -179,11 +200,6 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont goto error; } - ib = strdup(ctx->devname); - ucc_string_split(ib, ":", 2, &ib_name, &port); - ctx->ib_port = atoi(port); - ucc_free(ib); - /* Determine MTU */ if (ibv_query_port(ctx->ctx, ctx->ib_port, &port_attr)) { tl_error(lib, "couldn't query port in ctx create, errno %d", errno); @@ -191,7 +207,6 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont goto error; } - for (i = 0; i < UCC_TL_MLX5_MCAST_MAX_MTU_COUNT; i++) { if (mtu_lookup[i][1] == port_attr.max_mtu) { max_mtu = mtu_lookup[i][0]; @@ -292,7 +307,7 @@ ucc_status_t ucc_tl_mlx5_mcast_clean_ctx(ucc_tl_mlx5_mcast_coll_context_t *ctx) ctx->channel = NULL; } - if (ctx->devname && !strcmp(ctx->params.ib_dev_name, "")) { + if (ctx->devname && !ctx->user_provided_ib) { ucc_free(ctx->devname); ctx->devname = NULL; } diff --git a/src/components/tl/mlx5/tl_mlx5.h b/src/components/tl/mlx5/tl_mlx5.h index 8dbe4ff408..6524e7f162 100644 --- a/src/components/tl/mlx5/tl_mlx5.h +++ b/src/components/tl/mlx5/tl_mlx5.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ diff --git a/src/components/tl/mlx5/tl_mlx5_coll.c b/src/components/tl/mlx5/tl_mlx5_coll.c index 861d4a4c67..e918be166e 100644 --- a/src/components/tl/mlx5/tl_mlx5_coll.c +++ b/src/components/tl/mlx5/tl_mlx5_coll.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -16,6 +16,7 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args, ucc_tl_mlx5_task_t *task = NULL; if (UCC_COLL_ARGS_ACTIVE_SET(&coll_args->args)) { + tl_trace(team->context->lib, "mcast bcast not supported for active sets"); return UCC_ERR_NOT_SUPPORTED; } @@ -51,6 +52,7 @@ ucc_status_t ucc_tl_mlx5_task_finalize(ucc_coll_task_t *coll_task) if (req != NULL) { ucc_assert(coll_task->status != UCC_INPROGRESS); ucc_free(req); + tl_trace(UCC_TASK_LIB(task), "finalizing an mcast task %p", task); task->bcast_mcast.req_handle = NULL; } diff --git a/src/components/tl/mlx5/tl_mlx5_team.c b/src/components/tl/mlx5/tl_mlx5_team.c index 16c85b54b9..4185e6302b 100644 --- a/src/components/tl/mlx5/tl_mlx5_team.c +++ b/src/components/tl/mlx5/tl_mlx5_team.c @@ -222,8 +222,8 @@ ucc_status_t ucc_tl_mlx5_team_get_scores(ucc_base_team_t * tl_team, team_info.num_mem_types = 2; team_info.supported_mem_types = mt; team_info.supported_colls = - (UCC_COLL_TYPE_ALLTOALL * (team->a2a_status.local == UCC_OK)) | - UCC_COLL_TYPE_BCAST; + (UCC_COLL_TYPE_ALLTOALL * (team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_READY)) | + UCC_COLL_TYPE_BCAST * (team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_READY); team_info.size = UCC_TL_TEAM_SIZE(team); status = ucc_coll_score_build_default(