From 7cb8058a61414a22ac886a2c99bb4a3bad1151cb Mon Sep 17 00:00:00 2001 From: Sergey Lebedev Date: Thu, 4 Apr 2024 18:16:01 +0200 Subject: [PATCH] CL/HIER: fix int overflow in alltoall (#944) --- src/components/cl/hier/alltoall/alltoall.c | 59 ++++++++++++---------- 1 file changed, 31 insertions(+), 28 deletions(-) diff --git a/src/components/cl/hier/alltoall/alltoall.c b/src/components/cl/hier/alltoall/alltoall.c index 0048474c98..f2f602e6aa 100644 --- a/src/components/cl/hier/alltoall/alltoall.c +++ b/src/components/cl/hier/alltoall/alltoall.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -13,7 +13,7 @@ ucc_base_coll_alg_info_t {.id = UCC_CL_HIER_ALLTOALL_ALG_NODE_SPLIT, .name = "node_split", .desc = "splitting alltoall into two concurrent a2av calls" - " withing the node and outside of it"}, + " within the node and outside of it"}, [UCC_CL_HIER_ALLTOALL_ALG_LAST] = { .id = 0, .name = NULL, .desc = NULL}}; @@ -23,10 +23,12 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_alltoall_init, ucc_coll_task_t **task) { ucc_cl_hier_team_t *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t); + ucc_rank_t tsize = UCC_CL_TEAM_SIZE(cl_team); + uint64_t *counts, *displs; ucc_status_t status; ucc_base_coll_args_t args; - int i, count; - ucc_rank_t team_size; + int i; + size_t count; ucc_mc_buffer_header_t *h; if (UCC_IS_INPLACE(coll_args->args)) { @@ -45,42 +47,43 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_alltoall_init, args.args.mask = UCC_COLL_ARGS_FIELD_FLAGS; args.args.flags = 0; } - args.args.flags |= UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER | - UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER; - team_size = UCC_CL_TEAM_SIZE(cl_team); - status = - ucc_mc_alloc(&h, sizeof(int) * team_size * 2, UCC_MEMORY_TYPE_HOST); + args.args.flags |= UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER | + UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER | + UCC_COLL_ARGS_FLAG_COUNT_64BIT | + UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; + status = ucc_mc_alloc(&h, sizeof(uint64_t) * tsize * 2, + UCC_MEMORY_TYPE_HOST); if (ucc_unlikely(UCC_OK != status)) { cl_error(team->context->lib, "failed to allocate %zd bytes for full counts", - sizeof(int) * team_size * 2); + sizeof(uint64_t) * tsize * 2); return status; } - args.args.src.info_v.buffer = coll_args->args.src.info.buffer; - args.args.dst.info_v.buffer = coll_args->args.dst.info.buffer; - args.args.src.info_v.datatype = coll_args->args.src.info.datatype; - args.args.dst.info_v.datatype = coll_args->args.dst.info.datatype; - args.args.src.info_v.mem_type = coll_args->args.src.info.mem_type; - args.args.dst.info_v.mem_type = coll_args->args.dst.info.mem_type; - - args.args.src.info_v.counts = h->addr; - args.args.src.info_v.displacements = - PTR_OFFSET(h->addr, sizeof(int) * team_size); - + args.args.src.info_v.buffer = coll_args->args.src.info.buffer; + args.args.dst.info_v.buffer = coll_args->args.dst.info.buffer; + args.args.src.info_v.datatype = coll_args->args.src.info.datatype; + args.args.dst.info_v.datatype = coll_args->args.dst.info.datatype; + args.args.src.info_v.mem_type = coll_args->args.src.info.mem_type; + args.args.dst.info_v.mem_type = coll_args->args.dst.info.mem_type; + args.args.src.info_v.counts = h->addr; + args.args.src.info_v.displacements = PTR_OFFSET(h->addr, + sizeof(uint64_t) * tsize); args.args.dst.info_v.counts = args.args.src.info_v.counts; args.args.dst.info_v.displacements = args.args.src.info_v.displacements; - count = (int)coll_args->args.src.info.count / team_size; - ((int *)args.args.src.info_v.counts)[0] = count; - ((int *)args.args.src.info_v.displacements)[0] = 0; + counts = (uint64_t *)args.args.src.info_v.counts; + displs = (uint64_t *)args.args.src.info_v.displacements; - for (i = 1; i < team_size; i++) { - ((int *)args.args.src.info_v.counts)[i] = count; - ((int *)args.args.src.info_v.displacements)[i] = - ((int *)args.args.src.info_v.displacements)[i - 1] + count; + count = coll_args->args.src.info.count / tsize; + counts[0] = count; + displs[0] = 0; + for (i = 1; i < tsize; i++) { + counts[i] = count; + displs[i] = displs[i - 1] + count; } + status = ucc_cl_hier_alltoallv_init(&args, team, task); if (UCC_OK != status) { cl_error(team->context->lib, "failed to init split node a2av task");