Skip to content

Commit

Permalink
CL/HIER: fix int overflow in alltoall (#944)
Browse files (browse the repository at this point in the history)
  • Loading branch information
Sergei-Lebedev authored Apr 4, 2024
1 parent cd018cb commit 7cb8058
Showing 1 changed file with 31 additions and 28 deletions.
59 changes: 31 additions & 28 deletions src/components/cl/hier/alltoall/alltoall.c
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
@@ -13,7 +13,7 @@ ucc_base_coll_alg_info_t
{.id = UCC_CL_HIER_ALLTOALL_ALG_NODE_SPLIT,
.name = "node_split",
.desc = "splitting alltoall into two concurrent a2av calls"
" withing the node and outside of it"},
" within the node and outside of it"},
[UCC_CL_HIER_ALLTOALL_ALG_LAST] = {
.id = 0, .name = NULL, .desc = NULL}};

@@ -23,10 +23,12 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_alltoall_init,
ucc_coll_task_t **task)
{
ucc_cl_hier_team_t *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t);
ucc_rank_t tsize = UCC_CL_TEAM_SIZE(cl_team);
uint64_t *counts, *displs;
ucc_status_t status;
ucc_base_coll_args_t args;
int i, count;
ucc_rank_t team_size;
int i;
size_t count;
ucc_mc_buffer_header_t *h;

if (UCC_IS_INPLACE(coll_args->args)) {
@@ -45,42 +47,43 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_alltoall_init,
args.args.mask = UCC_COLL_ARGS_FIELD_FLAGS;
args.args.flags = 0;
}
args.args.flags |= UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER |
UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER;
team_size = UCC_CL_TEAM_SIZE(cl_team);

status =
ucc_mc_alloc(&h, sizeof(int) * team_size * 2, UCC_MEMORY_TYPE_HOST);
args.args.flags |= UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER |
UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER |
UCC_COLL_ARGS_FLAG_COUNT_64BIT |
UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
status = ucc_mc_alloc(&h, sizeof(uint64_t) * tsize * 2,
UCC_MEMORY_TYPE_HOST);
if (ucc_unlikely(UCC_OK != status)) {
cl_error(team->context->lib,
"failed to allocate %zd bytes for full counts",
sizeof(int) * team_size * 2);
sizeof(uint64_t) * tsize * 2);
return status;
}

args.args.src.info_v.buffer = coll_args->args.src.info.buffer;
args.args.dst.info_v.buffer = coll_args->args.dst.info.buffer;
args.args.src.info_v.datatype = coll_args->args.src.info.datatype;
args.args.dst.info_v.datatype = coll_args->args.dst.info.datatype;
args.args.src.info_v.mem_type = coll_args->args.src.info.mem_type;
args.args.dst.info_v.mem_type = coll_args->args.dst.info.mem_type;

args.args.src.info_v.counts = h->addr;
args.args.src.info_v.displacements =
PTR_OFFSET(h->addr, sizeof(int) * team_size);

args.args.src.info_v.buffer = coll_args->args.src.info.buffer;
args.args.dst.info_v.buffer = coll_args->args.dst.info.buffer;
args.args.src.info_v.datatype = coll_args->args.src.info.datatype;
args.args.dst.info_v.datatype = coll_args->args.dst.info.datatype;
args.args.src.info_v.mem_type = coll_args->args.src.info.mem_type;
args.args.dst.info_v.mem_type = coll_args->args.dst.info.mem_type;
args.args.src.info_v.counts = h->addr;
args.args.src.info_v.displacements = PTR_OFFSET(h->addr,
sizeof(uint64_t) * tsize);
args.args.dst.info_v.counts = args.args.src.info_v.counts;
args.args.dst.info_v.displacements = args.args.src.info_v.displacements;

count = (int)coll_args->args.src.info.count / team_size;
((int *)args.args.src.info_v.counts)[0] = count;
((int *)args.args.src.info_v.displacements)[0] = 0;
counts = (uint64_t *)args.args.src.info_v.counts;
displs = (uint64_t *)args.args.src.info_v.displacements;

for (i = 1; i < team_size; i++) {
((int *)args.args.src.info_v.counts)[i] = count;
((int *)args.args.src.info_v.displacements)[i] =
((int *)args.args.src.info_v.displacements)[i - 1] + count;
count = coll_args->args.src.info.count / tsize;
counts[0] = count;
displs[0] = 0;
for (i = 1; i < tsize; i++) {
counts[i] = count;
displs[i] = displs[i - 1] + count;
}

status = ucc_cl_hier_alltoallv_init(&args, team, task);
if (UCC_OK != status) {
cl_error(team->context->lib, "failed to init split node a2av task");
Expand Down

0 comments on commit 7cb8058

Please sign in to comment.