Skip to content

Commit

Permalink
TL/SHARP: Add reduce-scatter support
Browse files Browse the repository at this point in the history
  • Loading branch information
bureddy authored and Sergei-Lebedev committed Jan 13, 2024
1 parent 974605a commit b9da6fe
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 4 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/clang-tidy-nvidia.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on: [push, pull_request]
env:
OPEN_UCX_LINK: https://github.com/openucx/ucx
OPEN_UCX_BRANCH: master
HPCX_LINK: http://content.mellanox.com/hpc/hpc-x/v2.13/hpcx-v2.13-gcc-MLNX_OFED_LINUX-5-ubuntu20.04-cuda11-gdrcopy2-nccl2.12-x86_64.tbz
HPCX_LINK: https://content.mellanox.com/hpc/hpc-x/v2.17.1rc2/hpcx-v2.17.1-gcc-mlnx_ofed-ubuntu20.04-cuda12-x86_64.tbz
CLANG_VER: 12
MLNX_OFED_VER: 5.9-0.5.6.0
CUDA_VER: 11-4
Expand Down Expand Up @@ -45,8 +45,8 @@ jobs:
run: |
cd /tmp
wget ${HPCX_LINK}
tar xjf hpcx-v2.13-gcc-MLNX_OFED_LINUX-5-ubuntu20.04-cuda11-gdrcopy2-nccl2.12-x86_64.tbz
mv hpcx-v2.13-gcc-MLNX_OFED_LINUX-5-ubuntu20.04-cuda11-gdrcopy2-nccl2.12-x86_64 hpcx
tar xjf hpcx-v2.17.1-gcc-mlnx_ofed-ubuntu20.04-cuda12-x86_64.tbz
mv hpcx-v2.17.1-gcc-mlnx_ofed-ubuntu20.04-cuda12-x86_64 hpcx
- uses: actions/checkout@v1
- name: Build UCC
run: |
Expand Down
1 change: 1 addition & 0 deletions config/m4/sharp.m4
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ AS_IF([test "x$with_sharp" != "xno"],
AC_SUBST(SHARP_LDFLAGS, "-lsharp_coll -L$check_sharp_dir/lib")
AC_CHECK_DECLS([SHARP_COLL_HIDE_ERRORS], [], [], [[#include <sharp/api/sharp_coll.h>]])
AC_CHECK_DECLS([SHARP_COLL_DISABLE_LAZY_GROUP_RESOURCE_ALLOC], [], [], [[#include <sharp/api/sharp_coll.h>]])
AC_CHECK_DECLS([sharp_coll_do_reduce_scatter], [], [], [[#include <sharp/api/sharp_coll.h>]])
],
[
AS_IF([test "x$with_sharp" != "xguess"],
Expand Down
13 changes: 12 additions & 1 deletion src/components/tl/sharp/tl_sharp.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ typedef struct ucc_tl_sharp_task {
ucc_tl_sharp_reg_t *s_mem_h;
ucc_tl_sharp_reg_t *r_mem_h;
} allreduce;
struct {
ucc_tl_sharp_reg_t *s_mem_h;
ucc_tl_sharp_reg_t *r_mem_h;
} reduce_scatter;
struct {
ucc_tl_sharp_reg_t *mem_h;
} bcast;
Expand All @@ -131,9 +135,16 @@ ucc_status_t sharp_status_to_ucc_status(int status);
(ucc_derived_of((_task)->super.team->context->lib, ucc_tl_sharp_lib_t))
#define TASK_ARGS(_task) (_task)->super.bargs.args

#define UCC_TL_SHARP_SUPPORTED_COLLS \
#define UCC_TL_BASIC_SHARP_SUPPORTED_COLLS \
(UCC_COLL_TYPE_ALLREDUCE | UCC_COLL_TYPE_BARRIER | UCC_COLL_TYPE_BCAST)

#if HAVE_DECL_SHARP_COLL_DO_REDUCE_SCATTER
#define UCC_TL_SHARP_SUPPORTED_COLLS \
(UCC_TL_BASIC_SHARP_SUPPORTED_COLLS | UCC_COLL_TYPE_REDUCE_SCATTER)
#else
#define UCC_TL_SHARP_SUPPORTED_COLLS (UCC_TL_BASIC_SHARP_SUPPORTED_COLLS)
#endif

UCC_CLASS_DECLARE(ucc_tl_sharp_team_t, ucc_base_context_t *,
const ucc_base_team_params_t *);

Expand Down
94 changes: 94 additions & 0 deletions src/components/tl/sharp/tl_sharp_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,100 @@ ucc_status_t ucc_tl_sharp_bcast_start(ucc_coll_task_t *coll_task)
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
}

#if HAVE_DECL_SHARP_COLL_DO_REDUCE_SCATTER
ucc_status_t ucc_tl_sharp_reduce_scatter_start(ucc_coll_task_t *coll_task)
{
ucc_tl_sharp_task_t *task = ucc_derived_of(coll_task, ucc_tl_sharp_task_t);
ucc_tl_sharp_team_t *team = TASK_TEAM(task);
ucc_coll_args_t *args = &TASK_ARGS(task);
size_t count = args->dst.info.count;
ucc_datatype_t dt = args->dst.info.datatype;
struct sharp_coll_reduce_spec reduce_spec;
enum sharp_datatype sharp_type;
enum sharp_reduce_op op_type;
size_t src_data_size, dst_data_size;
int ret;

UCC_TL_SHARP_PROFILE_REQUEST_EVENT(coll_task, "sharp_reduce_scatter_start",
0);

sharp_type = ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(dt)];
op_type = ucc_to_sharp_reduce_op[args->op];
src_data_size = ucc_dt_size(dt) * count * UCC_TL_TEAM_SIZE(team);
dst_data_size = ucc_dt_size(dt) * count;

if (!UCC_IS_INPLACE(*args)) {
ucc_tl_sharp_mem_register(TASK_CTX(task), team, args->src.info.buffer,
src_data_size, &task->reduce_scatter.s_mem_h);
}
ucc_tl_sharp_mem_register(TASK_CTX(task), team, args->dst.info.buffer,
dst_data_size, &task->reduce_scatter.r_mem_h);

if (!UCC_IS_INPLACE(*args)) {
reduce_spec.sbuf_desc.buffer.ptr = args->src.info.buffer;
reduce_spec.sbuf_desc.buffer.mem_handle =
task->reduce_scatter.s_mem_h->mr;
reduce_spec.sbuf_desc.mem_type =
ucc_to_sharp_memtype[args->src.info.mem_type];
} else {
reduce_spec.sbuf_desc.buffer.ptr = args->dst.info.buffer;
reduce_spec.sbuf_desc.buffer.mem_handle =
task->reduce_scatter.r_mem_h->mr;
reduce_spec.sbuf_desc.mem_type =
ucc_to_sharp_memtype[args->dst.info.mem_type];
}

reduce_spec.sbuf_desc.buffer.length = src_data_size;
reduce_spec.sbuf_desc.type = SHARP_DATA_BUFFER;
reduce_spec.rbuf_desc.buffer.ptr = args->dst.info.buffer;
reduce_spec.rbuf_desc.buffer.length = dst_data_size;
reduce_spec.rbuf_desc.buffer.mem_handle = task->reduce_scatter.r_mem_h->mr;
reduce_spec.rbuf_desc.type = SHARP_DATA_BUFFER;
reduce_spec.rbuf_desc.mem_type =
ucc_to_sharp_memtype[args->dst.info.mem_type];
reduce_spec.aggr_mode = SHARP_AGGREGATION_NONE;
reduce_spec.length = count;
reduce_spec.dtype = sharp_type;
reduce_spec.op = op_type;
reduce_spec.offset = 0;

ret = sharp_coll_do_reduce_scatter_nb(team->sharp_comm, &reduce_spec,
&task->req_handle);
if (ret != SHARP_COLL_SUCCESS) {
tl_error(UCC_TASK_LIB(task),
"sharp_coll_do_reduce_scatter_nb failed:%s",
sharp_coll_strerror(ret));
coll_task->status = ucc_tl_sharp_status_to_ucc(ret);
return ucc_task_complete(coll_task);
}
coll_task->status = UCC_INPROGRESS;

return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
}

ucc_status_t ucc_tl_sharp_reduce_scatter_init(ucc_tl_sharp_task_t *task)
{
ucc_coll_args_t *args = &TASK_ARGS(task);

if (!ucc_coll_args_is_predefined_dt(args, UCC_RANK_INVALID)) {
return UCC_ERR_NOT_SUPPORTED;
}

if ((!UCC_IS_INPLACE(*args) &&
ucc_to_sharp_memtype[args->src.info.mem_type] == SHARP_MEM_TYPE_LAST) ||
ucc_to_sharp_memtype[args->dst.info.mem_type] == SHARP_MEM_TYPE_LAST ||
ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(args->dst.info.datatype)] ==
SHARP_DTYPE_NULL ||
ucc_to_sharp_reduce_op[args->op] == SHARP_OP_NULL) {
return UCC_ERR_NOT_SUPPORTED;
}

task->super.post = ucc_tl_sharp_reduce_scatter_start;
task->super.progress = ucc_tl_sharp_collective_progress;
return UCC_OK;
};
#endif

ucc_status_t ucc_tl_sharp_allreduce_init(ucc_tl_sharp_task_t *task)
{
ucc_coll_args_t *args = &TASK_ARGS(task);
Expand Down
3 changes: 3 additions & 0 deletions src/components/tl/sharp/tl_sharp_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,7 @@ ucc_status_t ucc_tl_sharp_barrier_init(ucc_tl_sharp_task_t *task);

ucc_status_t ucc_tl_sharp_bcast_init(ucc_tl_sharp_task_t *task);

#if HAVE_DECL_SHARP_COLL_DO_REDUCE_SCATTER
ucc_status_t ucc_tl_sharp_reduce_scatter_init(ucc_tl_sharp_task_t *task);
#endif
#endif
5 changes: 5 additions & 0 deletions src/components/tl/sharp/tl_sharp_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,11 @@ ucc_status_t ucc_tl_sharp_coll_init(ucc_base_coll_args_t *coll_args,
case UCC_COLL_TYPE_BCAST:
status = ucc_tl_sharp_bcast_init(task);
break;
#if HAVE_DECL_SHARP_COLL_DO_REDUCE_SCATTER
case UCC_COLL_TYPE_REDUCE_SCATTER:
status = ucc_tl_sharp_reduce_scatter_init(task);
break;
#endif
default:
tl_debug(UCC_TASK_LIB(task),
"collective %d is not supported by sharp tl",
Expand Down

0 comments on commit b9da6fe

Please sign in to comment.