From 7953c4646e577bbb8aea0883135f555069c1c8eb Mon Sep 17 00:00:00 2001
From: Nick Sarkauskas
Date: Tue, 28 May 2024 13:37:39 -0700
Subject: [PATCH] TL/UCP: Convert sliding window to schedule-based

Run the sliding window allreduce as a two-task schedule instead of a
single monolithic task: an RDMA task drives the pipeline, and a barrier
task chained to it through the event manager replaces the barrier that
the progress function previously managed inline. The key-exchange
allgather is initialized once at schedule-build time, and the buffer
and rkey bookkeeping moves to the new tl_ucp_dpu_offload.{h,c} pair so
it can be reused outside the allreduce code.
---
 src/components/tl/ucp/Makefile.am             |   2 +
 src/components/tl/ucp/allreduce/allreduce.c   | 121 +++++++--
 src/components/tl/ucp/allreduce/allreduce.h   |   8 +-
 .../ucp/allreduce/allreduce_sliding_window.c  | 252 +++++++-----
 .../ucp/allreduce/allreduce_sliding_window.h  |  25 +-
 .../allreduce_sliding_window_setup.c          |  87 +++---
 src/components/tl/ucp/tl_ucp_coll.h           |  14 +-
 src/components/tl/ucp/tl_ucp_dpu_offload.c    |  42 +++
 src/components/tl/ucp/tl_ucp_dpu_offload.h    |  55 ++++
 9 files changed, 355 insertions(+), 251 deletions(-)
 create mode 100644 src/components/tl/ucp/tl_ucp_dpu_offload.c
 create mode 100644 src/components/tl/ucp/tl_ucp_dpu_offload.h

diff --git a/src/components/tl/ucp/Makefile.am b/src/components/tl/ucp/Makefile.am
index 0f10f00c05..b196479893 100644
--- a/src/components/tl/ucp/Makefile.am
+++ b/src/components/tl/ucp/Makefile.am
@@ -113,6 +113,8 @@ sources = \
 	tl_ucp_ep.c \
 	tl_ucp_coll.c \
 	tl_ucp_service_coll.c \
+	tl_ucp_dpu_offload.h \
+	tl_ucp_dpu_offload.c \
 	$(allgather) \
 	$(allgatherv) \
 	$(alltoall) \
diff --git a/src/components/tl/ucp/allreduce/allreduce.c b/src/components/tl/ucp/allreduce/allreduce.c
index 8467d5ed2e..051705a625 100644
--- a/src/components/tl/ucp/allreduce/allreduce.c
+++ b/src/components/tl/ucp/allreduce/allreduce.c
@@ -7,6 +7,12 @@
 #include "tl_ucp.h"
 #include "allreduce.h"
 #include "utils/ucc_coll_utils.h"
+#include "tl_ucp_dpu_offload.h"
+#include "../allgather/allgather.h"
+
+ucc_status_t
+ucc_tl_ucp_allreduce_sliding_window_alloc_pipe(ucc_base_team_t *team,
+                                               ucc_tl_ucp_task_t *task);
 
 ucc_base_coll_alg_info_t
     ucc_tl_ucp_allreduce_algs[UCC_TL_UCP_ALLREDUCE_ALG_LAST + 1] = {
@@ -61,34 +67,115 @@
 ucc_tl_ucp_allreduce_sliding_window_init(ucc_base_coll_args_t *coll_args,
                                          ucc_base_team_t      *team,
                                          ucc_coll_task_t     **task_h)
 {
-    ucc_tl_ucp_task_t *task;
-    ucc_status_t       status  = UCC_OK;
-    ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
+    ucc_tl_ucp_allreduce_sw_host_allgather_t *allgather_data;
+    ucc_tl_ucp_task_t                        *rdma_task;
+    ucc_coll_task_t                          *barrier_task;
+    ucc_schedule_t                           *schedule = NULL;
+    ucc_status_t                              status   = UCC_OK;
+    ucc_tl_ucp_team_t                        *tl_team  =
+        ucc_derived_of(team, ucc_tl_ucp_team_t);
+    size_t                                    allgather_size =
+        sizeof(ucc_tl_ucp_allreduce_sw_host_allgather_t);
+    ucc_rank_t                                size = team->params.size;
+    ucc_base_coll_args_t                      bargs = {
+        .mask = 0,
+        .args = {
+            .coll_type = UCC_COLL_TYPE_ALLGATHER,
+            .mask      = 0,
+            .src.info  = {.buffer   = NULL,
+                          .count    = allgather_size,
+                          .datatype = UCC_DT_UINT8,
+                          .mem_type = UCC_MEMORY_TYPE_HOST},
+            .dst.info  = {.buffer   = NULL,
+                          .count    = allgather_size * size,
+                          .datatype = UCC_DT_UINT8,
+                          .mem_type = UCC_MEMORY_TYPE_HOST}
+        }
+    };
+    ucc_base_coll_args_t barrier_coll_args = {
+        .team           = team->params.team,
+        .args.coll_type = UCC_COLL_TYPE_BARRIER,
+    };
 
-    ALLREDUCE_TASK_CHECK(coll_args->args, tl_team);
+    status = ucc_tl_ucp_get_schedule(tl_team, coll_args,
+                                     (ucc_tl_ucp_schedule_t **)&schedule);
+    if (ucc_unlikely(UCC_OK != status)) {
+        return status;
+    }
+
+    *task_h                  = &schedule->super;
+    schedule->super.post     = ucc_tl_ucp_allreduce_sliding_window_start;
+    schedule->super.progress = NULL;
+    schedule->super.finalize = ucc_tl_ucp_allreduce_sliding_window_finalize;
+
+    schedule->super.flags   |= UCC_COLL_TASK_FLAG_EXECUTOR;
 
-    task = ucc_tl_ucp_init_task(coll_args, team);
-    if (ucc_unlikely(!task)) {
-        ucc_error("couldnt allocate task");
-        return UCC_ERR_NO_MEMORY;
+    rdma_task = ucc_tl_ucp_init_task(coll_args, team);
+    if (ucc_unlikely(!rdma_task)) {
+        ucc_error("couldn't allocate task");
+        status = UCC_ERR_NO_MEMORY;
+        goto out;
     }
 
-    *task_h              = &task->super;
-    task->super.post     = ucc_tl_ucp_allreduce_sliding_window_start;
-    task->super.progress = ucc_tl_ucp_allreduce_sliding_window_progress;
-    task->super.finalize = ucc_tl_ucp_allreduce_sliding_window_finalize;
-    status = ucc_tl_ucp_allreduce_sliding_window_task_init(coll_args, team, task);
+    status = ucc_tl_ucp_allreduce_sliding_window_alloc_pipe(team, rdma_task);
+    if (UCC_OK != status) {
+        ucc_error("failed to alloc pipe: %s", ucc_status_string(status));
+        goto free_rdma_task;
+    }
+
+    status = ucc_tl_ucp_allreduce_sliding_window_task_init(coll_args, team,
+                                                           rdma_task);
     if (status != UCC_OK) {
-        ucc_tl_ucp_put_task(task);
         ucc_error("failed to init task: %s", ucc_status_string(status));
+        goto out;
     }
-    task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
+    allgather_data = rdma_task->allreduce_sliding_window.allgather_data;
+    bargs.args.src.info.buffer = allgather_data;
+    bargs.args.dst.info.buffer = PTR_OFFSET(allgather_data, allgather_size);
 
-    if (UCC_OK != status) {
-        ucc_error("failed to init executor: %s", ucc_status_string(status));
+    rdma_task->super.post     = ucc_tl_ucp_allreduce_sliding_window_rdma_task_post;
+    rdma_task->super.progress = ucc_tl_ucp_allreduce_sliding_window_rdma_progress;
+    rdma_task->super.finalize = ucc_tl_ucp_allreduce_sliding_window_rdma_task_finalize;
+
+    UCC_CHECK_GOTO(ucc_tl_ucp_allgather_ring_init(&bargs, team,
+                       &rdma_task->allreduce_sliding_window.allgather_task),
+                   free_rdma_pipe, status);
+
+    status = ucc_tl_ucp_coll_init(&barrier_coll_args, team,
+                                  &barrier_task);
+    if (status < 0) {
+        tl_error(team->context->lib,
+                 "failure during sliding window barrier init: %s",
+                 ucc_status_string(status));
+        goto free_allgather_task;
     }
 
+    UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, &rdma_task->super), free_barrier_task, status);
+    UCC_CHECK_GOTO(ucc_event_manager_subscribe(&schedule->super,
+                                               UCC_EVENT_SCHEDULE_STARTED,
+                                               &rdma_task->super,
+                                               ucc_task_start_handler),
+                   free_barrier_task, status);
+
+    UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, barrier_task), free_barrier_task, status);
+    UCC_CHECK_GOTO(ucc_event_manager_subscribe(&rdma_task->super,
+                                               UCC_EVENT_COMPLETED,
+                                               barrier_task,
+                                               ucc_task_start_handler),
+                   free_barrier_task, status);
+
+    return status;
+
+free_barrier_task:
+    ucc_tl_ucp_coll_finalize(barrier_task);
+free_allgather_task:
+    ucc_tl_ucp_coll_finalize(rdma_task->allreduce_sliding_window.allgather_task);
+free_rdma_pipe:
+    ucc_tl_ucp_allreduce_sliding_window_free_pipe(&rdma_task->super);
+free_rdma_task:
+    ucc_tl_ucp_allreduce_sliding_window_free_task(&rdma_task->super);
 out:
+    ucc_tl_ucp_put_schedule(schedule);
     return status;
 }
diff --git a/src/components/tl/ucp/allreduce/allreduce.h b/src/components/tl/ucp/allreduce/allreduce.h
index fc895b91d9..a95f39dc13 100644
--- a/src/components/tl/ucp/allreduce/allreduce.h
+++ b/src/components/tl/ucp/allreduce/allreduce.h
@@ -70,11 +70,17 @@ void ucc_tl_ucp_allreduce_knomial_progress(ucc_coll_task_t *task);
 ucc_status_t
 ucc_tl_ucp_allreduce_sliding_window_start(ucc_coll_task_t *coll_task);
 
-void ucc_tl_ucp_allreduce_sliding_window_progress(ucc_coll_task_t *task);
+void ucc_tl_ucp_allreduce_sliding_window_rdma_progress(ucc_coll_task_t *task);
 
 ucc_status_t
 ucc_tl_ucp_allreduce_sliding_window_finalize(ucc_coll_task_t *task);
 
+ucc_status_t
+ucc_tl_ucp_allreduce_sliding_window_rdma_task_finalize(ucc_coll_task_t *coll_task); + +ucc_status_t +ucc_tl_ucp_allreduce_sliding_window_rdma_task_post(ucc_coll_task_t *coll_task); + ucc_status_t ucc_tl_ucp_allreduce_knomial_finalize(ucc_coll_task_t *task); ucc_status_t ucc_tl_ucp_allreduce_sra_knomial_init(ucc_base_coll_args_t *coll_args, diff --git a/src/components/tl/ucp/allreduce/allreduce_sliding_window.c b/src/components/tl/ucp/allreduce/allreduce_sliding_window.c index 3800d302af..2ec7bb60e1 100644 --- a/src/components/tl/ucp/allreduce/allreduce_sliding_window.c +++ b/src/components/tl/ucp/allreduce/allreduce_sliding_window.c @@ -13,41 +13,6 @@ ucc_status_t ucc_tl_ucp_barrier_knomial_start(ucc_coll_task_t *task); -static ucc_status_t ucc_tl_ucp_allreduce_sliding_window_register( - ucp_context_h ucp_context, ucc_tl_ucp_team_t *tl_team, - struct ucc_tl_ucp_allreduce_sw_export_buf *ebuf, void *packed_memh) -{ - ucs_status_t ucs_status; - ucp_mem_map_params_t params = {0}; - - ebuf->ucp_context = ucp_context; - - params.field_mask = UCP_MEM_MAP_PARAM_FIELD_EXPORTED_MEMH_BUFFER; - params.exported_memh_buffer = packed_memh; - - ucs_status = ucp_mem_map(ucp_context, ¶ms, &ebuf->memh); - if (UCS_OK != ucs_status) { - tl_error(UCC_TL_TEAM_LIB(tl_team), - "import using ucp_mem_map() returned error: %s\n", - ucs_status_string(ucs_status)); - return ucs_status_to_ucc_status(ucs_status); - } - - ucs_status = ucp_rkey_pack(ucp_context, ebuf->memh, &ebuf->packed_key, - &ebuf->packed_key_len); - if (UCS_OK != ucs_status) { - ucs_status_t unmap_status = ucp_mem_unmap(ucp_context, ebuf->memh); - tl_error(UCC_TL_TEAM_LIB(tl_team), - "ucp_rkey_pack() returned error: %s%s\n", - ucs_status_string(ucs_status), - unmap_status == UCS_OK ? "" : - ". While handling this error, unmapping the memh had an error\n"); - return ucs_status_to_ucc_status(ucs_status); - } - - return UCC_OK; -} - static inline void ucc_tl_ucp_allreduce_sliding_window_reset_buf(ucc_tl_ucp_allreduce_sw_buf_t *buf) { @@ -85,79 +50,62 @@ static inline void ucc_tl_ucp_allreduce_sliding_window_reset_pipeline( ucc_status_t ucc_tl_ucp_allreduce_sliding_window_start(ucc_coll_task_t *coll_task) { + + ucc_tl_ucp_allreduce_sw_pipeline_t *pipe; + ucc_tl_ucp_allreduce_sw_host_allgather_t *allgather_data; ucc_base_coll_args_t *coll_args = &coll_task->bargs; - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, - ucc_tl_ucp_task_t); - ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_schedule_t *schedule = ucc_derived_of(coll_task, + ucc_schedule_t); + ucc_base_team_t *base_team = schedule->super.team; + ucc_tl_ucp_team_t *team = ucc_derived_of(base_team, + ucc_tl_ucp_team_t); ucc_tl_ucp_context_t *tl_ctx = UCC_TL_UCP_TEAM_CTX(team); ucc_rank_t rank = UCC_TL_TEAM_RANK(team); uint32_t count_total = coll_task->bargs.args.dst.info.count; ucc_rank_t size = coll_task->team->params.size; - ucc_datatype_t dtype = TASK_ARGS(task).dst.info.datatype; + ucc_datatype_t dtype = coll_args->args.dst.info.datatype; size_t dt_size = ucc_dt_size(dtype); int inplace = UCC_IS_INPLACE(coll_args->args); ucc_status_t status = UCC_OK; int put_window_size = UCC_TL_UCP_TEAM_LIB(team) ->cfg.allreduce_sliding_window_put_window_size; - ucc_tl_ucp_allreduce_sw_pipeline_t *pipe = - task->allreduce_sliding_window.pipe; - ucc_tl_ucp_allreduce_sw_host_allgather_t *allgather_data = - task->allreduce_sliding_window.allgather_data; - size_t allgather_size = sizeof(ucc_tl_ucp_allreduce_sw_host_allgather_t); ucc_tl_ucp_allreduce_sw_global_work_buf_info_t *gwbi_p = 
coll_args->args.global_work_buffer; + ucc_tl_ucp_task_t *rdma_task = ucc_derived_of(schedule->tasks[0], + ucc_tl_ucp_task_t); - ucc_base_coll_args_t bargs = { - .mask = 0, - .args = { - .coll_type = UCC_COLL_TYPE_ALLGATHER, - .mask = 0, - .src.info = {.buffer = allgather_data, - .count = allgather_size, - .datatype = UCC_DT_UINT8, - .mem_type = UCC_MEMORY_TYPE_HOST}, - .dst.info = {.buffer = PTR_OFFSET(allgather_data, allgather_size), - .count = allgather_size * size, - .datatype = UCC_DT_UINT8, - .mem_type = UCC_MEMORY_TYPE_HOST} - } - }; + pipe = rdma_task->allreduce_sliding_window.pipe; + allgather_data = rdma_task->allreduce_sliding_window.allgather_data; - // Register the src and dst bufs + // Register the src buf if (!inplace) { status = ucc_tl_ucp_allreduce_sliding_window_register( tl_ctx->worker.ucp_context, team, - task->allreduce_sliding_window.src_ebuf, gwbi_p->packed_src_memh); + rdma_task->allreduce_sliding_window.bufs->src_ebuf, + gwbi_p->packed_src_memh); if (status != UCC_OK) { - tl_error(UCC_TASK_LIB(task), "failed to register src memh: %s", + tl_error(UCC_TASK_LIB(rdma_task), "failed to register src memh: %s", ucc_status_string(status)); goto out; } memcpy(allgather_data->packed_src_key, - task->allreduce_sliding_window.src_ebuf->packed_key, - task->allreduce_sliding_window.src_ebuf->packed_key_len); + rdma_task->allreduce_sliding_window.bufs->src_ebuf->packed_key, + rdma_task->allreduce_sliding_window.bufs->src_ebuf->packed_key_len); } + // Register the dst buf status = ucc_tl_ucp_allreduce_sliding_window_register( tl_ctx->worker.ucp_context, team, - task->allreduce_sliding_window.dst_ebuf, gwbi_p->packed_dst_memh); + rdma_task->allreduce_sliding_window.bufs->dst_ebuf, + gwbi_p->packed_dst_memh); if (status != UCC_OK) { - tl_error(UCC_TASK_LIB(task), "failed to register dst memh: %s", + tl_error(UCC_TASK_LIB(rdma_task), "failed to register dst memh: %s", ucc_status_string(status)); goto out; } memcpy(allgather_data->packed_dst_key, - task->allreduce_sliding_window.dst_ebuf->packed_key, - task->allreduce_sliding_window.dst_ebuf->packed_key_len); - - UCC_CHECK_GOTO(ucc_tl_ucp_allgather_ring_init(&bargs, - &team->super.super, - &task->allreduce_sliding_window.allgather_task), - out, status); - - UCC_CHECK_GOTO(ucc_tl_ucp_allgather_ring_start( - task->allreduce_sliding_window.allgather_task), - out, status); + rdma_task->allreduce_sliding_window.bufs->dst_ebuf->packed_key, + rdma_task->allreduce_sliding_window.bufs->dst_ebuf->packed_key_len); if (put_window_size <= 0) put_window_size = size; @@ -171,27 +119,70 @@ ucc_tl_ucp_allreduce_sliding_window_start(ucc_coll_task_t *coll_task) pipe->my_count += count_total % size; } - ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); + rdma_task->allreduce_sliding_window.reduce_task = NULL; - task->allreduce_sliding_window.reduce_task = NULL; - task->allreduce_sliding_window.barrier_task = NULL; + UCC_CHECK_GOTO(ucc_tl_ucp_allgather_ring_start( + rdma_task->allreduce_sliding_window.allgather_task), + out, status); - return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); + return ucc_schedule_start(coll_task); out: - ucc_tl_ucp_allreduce_sliding_window_free_task(coll_task); - ucc_tl_ucp_allreduce_sliding_window_free_pipe(coll_task); - ucc_tl_ucp_coll_finalize(task->allreduce_sliding_window.allgather_task); - tl_error(UCC_TASK_LIB(task), "failed to start allreduce sliding window: %s", ucc_status_string(status)); + tl_error(UCC_TASK_LIB(rdma_task), "failed to start allreduce sliding window: %s", + 
ucc_status_string(status));
     return status;
 }
 
 ucc_status_t
 ucc_tl_ucp_allreduce_sliding_window_finalize(ucc_coll_task_t *coll_task)
+{
+    ucc_schedule_t *schedule = ucc_derived_of(coll_task, ucc_schedule_t);
+    ucc_status_t    status;
+
+    status = ucc_schedule_finalize(coll_task);
+    ucc_tl_ucp_put_schedule(schedule);
+
+    return status;
+}
+
+ucc_status_t
+ucc_tl_ucp_allreduce_sliding_window_rdma_task_post(
+    ucc_coll_task_t *coll_task)
+{
+    ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task,
+                                             ucc_tl_ucp_task_t);
+    ucc_tl_ucp_team_t *team = TASK_TEAM(task);
+
+    ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
+
+    return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
+}
+
+static inline void ucc_tl_ucp_allreduce_sliding_window_free_rkeys(
+    ucc_coll_task_t *coll_task)
+{
+    int                i;
+    ucc_base_team_t   *team      = coll_task->team;
+    ucc_rank_t         team_size = (ucc_rank_t)team->params.size;
+    ucc_tl_ucp_task_t *task      = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
+    int                inplace   = UCC_IS_INPLACE(coll_task->bargs.args);
+
+    for (i = 0; i < team_size; i++) {
+        if (!inplace) {
+            ucp_rkey_destroy(task->allreduce_sliding_window.bufs->src_rkeys[i]);
+        }
+        ucp_rkey_destroy(task->allreduce_sliding_window.bufs->dst_rkeys[i]);
+    }
+}
+
+ucc_status_t
+ucc_tl_ucp_allreduce_sliding_window_rdma_task_finalize(
+    ucc_coll_task_t *coll_task)
 {
     ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
     ucc_status_t       st   = UCC_OK;
 
+    ucc_tl_ucp_allreduce_sliding_window_free_rkeys(coll_task);
     ucc_tl_ucp_allreduce_sliding_window_free_task(coll_task);
     ucc_tl_ucp_allreduce_sliding_window_free_pipe(coll_task);
 
@@ -208,14 +199,22 @@ static inline void ucc_tl_ucp_allreduce_sliding_window_reduction(
     ucc_coll_task_t *coll_task, ucc_tl_ucp_allreduce_sw_buf_t *accbuf,
     ucc_tl_ucp_allreduce_sw_buf_t *getbuf)
 {
+    ucc_ee_executor_t *exec;
     ucc_status_t       status = UCC_OK;
     ucc_tl_ucp_task_t *task   = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
     ucc_coll_args_t   *args   = &TASK_ARGS(task);
     ucc_datatype_t     dt     = TASK_ARGS(task).dst.info.datatype;
 
+    status = ucc_coll_task_get_executor(&task->super, &exec);
+    if (ucc_unlikely(status != UCC_OK)) {
+        tl_error(UCC_TASK_LIB(task), "failed to get executor");
+        task->super.status = status;
+        return;
+    }
+
     status = ucc_dt_reduce(accbuf->buf, getbuf->buf, accbuf->buf,
                            accbuf->count, dt,
-                           args, 0, 0, task->super.executor,
+                           args, 0, 0, exec,
                            &task->allreduce_sliding_window.reduce_task);
 
     if (ucc_unlikely(UCC_OK != status)) {
@@ -290,56 +287,7 @@ static inline void ucc_tl_ucp_allreduce_sliding_window_key_exchange_progress(
     goto out;
 }
 
-static inline void ucc_tl_ucp_allreduce_sliding_window_free_rkeys(
-    ucc_coll_task_t *coll_task)
-{
-    int                i;
-    ucc_base_team_t   *team      = coll_task->team;
-    ucc_rank_t         team_size = (ucc_rank_t)team->params.size;
-    ucc_tl_ucp_task_t *task      = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
-    int                inplace   = UCC_IS_INPLACE(coll_task->bargs.args);
-
-    for (i = 0; i < team_size; i++) {
-        if (!inplace) {
-            ucp_rkey_destroy(task->allreduce_sliding_window.src_rkeys[i]);
-        }
-        ucp_rkey_destroy(task->allreduce_sliding_window.dst_rkeys[i]);
-    }
-}
-
-static inline void
-ucc_tl_ucp_allreduce_sliding_window_barrier(ucc_coll_task_t *coll_task)
-{
-    ucc_status_t       status;
-    ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
-    ucc_base_team_t   *team = coll_task->team;
-
-    ucc_base_coll_args_t coll_args = {
-        .team           = coll_task->team->params.team,
-        .args.coll_type = UCC_COLL_TYPE_BARRIER,
-    };
-
-    status = ucc_tl_ucp_coll_init(&coll_args, team,
-                                  &task->allreduce_sliding_window.barrier_task);
-    if (status < 
0) { - tl_error(coll_task->team->context->lib, - "failure during sliding window barrier init: %s", - ucc_status_string(status)); - task->super.status = status; - return; - } - - status = ucc_tl_ucp_barrier_knomial_start( - task->allreduce_sliding_window.barrier_task); - if (status < 0) { - tl_error(coll_task->team->context->lib, - "failure during sliding window barrier start: %s", - ucc_status_string(status)); - task->super.status = status; - } -} - -void ucc_tl_ucp_allreduce_sliding_window_progress(ucc_coll_task_t *coll_task) +void ucc_tl_ucp_allreduce_sliding_window_rdma_progress(ucc_coll_task_t *coll_task) { ucc_tl_ucp_allreduce_sw_buf_t *redbuf; ucc_tl_ucp_allreduce_sw_buf_t *getbuf; @@ -373,8 +321,6 @@ void ucc_tl_ucp_allreduce_sliding_window_progress(ucc_coll_task_t *coll_task) int i = 0; ucc_coll_task_t *allgather_task = task->allreduce_sliding_window.allgather_task; - ucc_coll_task_t *barrier_task = - task->allreduce_sliding_window.barrier_task; ucc_ee_executor_task_t **reduce_task = &task->allreduce_sliding_window.reduce_task; int put_window_size = @@ -383,20 +329,6 @@ void ucc_tl_ucp_allreduce_sliding_window_progress(ucc_coll_task_t *coll_task) ucc_assert(host_team_size > 0); - if (barrier_task != NULL) { - // mark sliding window task complete once barrier finishes - if (barrier_task->super.status == UCC_OK) { - ucc_tl_ucp_put_task( - ucc_derived_of(task->allreduce_sliding_window.barrier_task, - ucc_tl_ucp_task_t)); - task->allreduce_sliding_window.barrier_task = NULL; - task->super.status = UCC_OK; - } - - ucc_assert(barrier_task->super.status >= 0); - return; - } - if (allgather_task != NULL) { ucc_tl_ucp_allreduce_sliding_window_key_exchange_progress(coll_task); return; @@ -424,7 +356,8 @@ void ucc_tl_ucp_allreduce_sliding_window_progress(ucc_coll_task_t *coll_task) src_rank = pipe->src_rank; getbuf = accbuf->state == FREE ? 
accbuf : &pipe->getbuf[get_idx]; src_addr = (char*) - task->allreduce_sliding_window.sbufs[src_rank] + get_offset; + task->allreduce_sliding_window.bufs->sbufs[src_rank] + + get_offset; dst_addr = getbuf->buf; ucc_assert(getbuf->state == FREE); @@ -436,7 +369,8 @@ void ucc_tl_ucp_allreduce_sliding_window_progress(ucc_coll_task_t *coll_task) getbuf->ucp_req = ucp_get_nbx( ep, dst_addr, data_size, (uint64_t)src_addr, - task->allreduce_sliding_window.src_rkeys[src_rank], &req_param); + task->allreduce_sliding_window.bufs->src_rkeys[src_rank], + &req_param); pipe->src_rank = (src_rank + 1) % host_team_size; @@ -511,7 +445,8 @@ void ucc_tl_ucp_allreduce_sliding_window_progress(ucc_coll_task_t *coll_task) dst_rank = pipe->dst_rank; src_addr = accbuf->buf; dst_addr = (char*) - task->allreduce_sliding_window.rbufs[dst_rank] + put_offset; + task->allreduce_sliding_window.bufs->rbufs[dst_rank] + + put_offset; put_idx = pipe->posted_put % put_window_size; @@ -530,7 +465,7 @@ void ucc_tl_ucp_allreduce_sliding_window_progress(ucc_coll_task_t *coll_task) ucp_put_nbx( ep, src_addr, data_size, (uint64_t)dst_addr, - task->allreduce_sliding_window.dst_rkeys[dst_rank], + task->allreduce_sliding_window.bufs->dst_rkeys[dst_rank], &req_param); pipe->posted_put++; @@ -576,7 +511,6 @@ void ucc_tl_ucp_allreduce_sliding_window_progress(ucc_coll_task_t *coll_task) } if (pipe->count_serviced == pipe->my_count) { - ucc_tl_ucp_allreduce_sliding_window_barrier(coll_task); - ucc_tl_ucp_allreduce_sliding_window_free_rkeys(coll_task); + task->super.status = UCC_OK; } } diff --git a/src/components/tl/ucp/allreduce/allreduce_sliding_window.h b/src/components/tl/ucp/allreduce/allreduce_sliding_window.h index 7dd1d07635..431dac825d 100644 --- a/src/components/tl/ucp/allreduce/allreduce_sliding_window.h +++ b/src/components/tl/ucp/allreduce/allreduce_sliding_window.h @@ -8,13 +8,7 @@ #define ALLREDUCE_SW_H_ #include "tl_ucp_coll.h" - -#define ALLREDUCE_PACKED_KEY_MAX_LEN 1024 - -typedef struct ucc_tl_ucp_allreduce_sw_global_work_buf_info { - void *packed_src_memh; - void *packed_dst_memh; -} ucc_tl_ucp_allreduce_sw_global_work_buf_info_t; +#include "tl_ucp_dpu_offload.h" typedef enum ucc_tl_ucp_allreduce_sw_buf_state { FREE, @@ -56,21 +50,4 @@ typedef struct ucc_tl_ucp_allreduce_sw_pipeline { int posted_put; } ucc_tl_ucp_allreduce_sw_pipeline_t; -struct ucc_tl_ucp_allreduce_sw_export_buf { - ucp_context_h ucp_context; - ucp_mem_h memh; - void *packed_memh; - size_t packed_memh_len; - void *packed_key; - size_t packed_key_len; - uint64_t memh_id; -}; - -typedef struct ucc_tl_ucp_allreduce_sw_host_allgather { - void *src_buf; - void *dst_buf; - char packed_src_key[ALLREDUCE_PACKED_KEY_MAX_LEN]; - char packed_dst_key[ALLREDUCE_PACKED_KEY_MAX_LEN]; -} ucc_tl_ucp_allreduce_sw_host_allgather_t; - #endif diff --git a/src/components/tl/ucp/allreduce/allreduce_sliding_window_setup.c b/src/components/tl/ucp/allreduce/allreduce_sliding_window_setup.c index ae6ce98190..ce969116c6 100644 --- a/src/components/tl/ucp/allreduce/allreduce_sliding_window_setup.c +++ b/src/components/tl/ucp/allreduce/allreduce_sliding_window_setup.c @@ -69,6 +69,8 @@ ucc_tl_ucp_allreduce_sliding_window_alloc_pipe(ucc_base_team_t *team, } task->allreduce_sliding_window.pipe = pipe; + task->allreduce_sliding_window.put_requests = + task->allreduce_sliding_window.pipe->put_requests; return UCC_OK; @@ -104,38 +106,37 @@ ucc_tl_ucp_allreduce_sliding_window_task_init(ucc_base_coll_args_t *coll_args, ucc_assert(team_size > 0); - if 
(ucc_tl_ucp_allreduce_sliding_window_alloc_pipe(team, task) != UCC_OK) { + task->allreduce_sliding_window.bufs = + ucc_malloc(sizeof(ucc_tl_ucp_dpu_offload_buf_info_t)); + if (task->allreduce_sliding_window.bufs == NULL) { goto err; } allgather_data = ucc_malloc(allgather_size * (team_size + 1)); if (allgather_data == NULL) { - goto free_pipe; + goto free_bufs; } + task->allreduce_sliding_window.allgather_data = allgather_data; gwbi_p = coll_args->args.global_work_buffer; task->super.bargs.args.global_work_buffer = gwbi_p; - task->allreduce_sliding_window.barrier_task = NULL; task->allreduce_sliding_window.reduce_task = NULL; - task->allreduce_sliding_window.rbufs = + task->allreduce_sliding_window.bufs->rbufs = ucc_malloc(sizeof(void *) * team_size); - if (task->allreduce_sliding_window.rbufs == NULL) { + if (task->allreduce_sliding_window.bufs->rbufs == NULL) { goto free_allgather_data; } - task->allreduce_sliding_window.dst_rkeys = + task->allreduce_sliding_window.bufs->dst_rkeys = ucc_malloc(sizeof(ucp_rkey_h) * team_size); - if (task->allreduce_sliding_window.dst_rkeys == NULL) { + if (task->allreduce_sliding_window.bufs->dst_rkeys == NULL) { goto free_rbufs; } - task->allreduce_sliding_window.put_requests = - task->allreduce_sliding_window.pipe->put_requests; - - task->allreduce_sliding_window.dst_ebuf = + task->allreduce_sliding_window.bufs->dst_ebuf = ucc_malloc(sizeof(struct ucc_tl_ucp_allreduce_sw_export_buf)); - if (task->allreduce_sliding_window.dst_ebuf == NULL) { + if (task->allreduce_sliding_window.bufs->dst_ebuf == NULL) { goto free_dst_rkeys; } @@ -147,42 +148,42 @@ ucc_tl_ucp_allreduce_sliding_window_task_init(ucc_base_coll_args_t *coll_args, if (!inplace) { allgather_data->src_buf = src_buf; - task->allreduce_sliding_window.sbufs = + task->allreduce_sliding_window.bufs->sbufs = ucc_malloc(sizeof(void *) * team_size); - if (task->allreduce_sliding_window.sbufs == NULL) { + if (task->allreduce_sliding_window.bufs->sbufs == NULL) { goto free_dst_ebuf; } - task->allreduce_sliding_window.src_rkeys = + task->allreduce_sliding_window.bufs->src_rkeys = ucc_malloc(sizeof(ucp_rkey_h) * team_size); - if (task->allreduce_sliding_window.src_rkeys == NULL) { + if (task->allreduce_sliding_window.bufs->src_rkeys == NULL) { goto free_sbufs; } - task->allreduce_sliding_window.src_ebuf = + task->allreduce_sliding_window.bufs->src_ebuf = ucc_malloc(sizeof(struct ucc_tl_ucp_allreduce_sw_export_buf)); - if (task->allreduce_sliding_window.src_ebuf == NULL) { + if (task->allreduce_sliding_window.bufs->src_ebuf == NULL) { goto free_src_rkeys; } } else { - task->allreduce_sliding_window.src_ebuf = NULL; + task->allreduce_sliding_window.bufs->src_ebuf = NULL; } return UCC_OK; free_src_rkeys: - ucc_free(task->allreduce_sliding_window.src_rkeys); + ucc_free(task->allreduce_sliding_window.bufs->src_rkeys); free_sbufs: - ucc_free(task->allreduce_sliding_window.sbufs); + ucc_free(task->allreduce_sliding_window.bufs->sbufs); free_dst_ebuf: - ucc_free(task->allreduce_sliding_window.dst_ebuf); + ucc_free(task->allreduce_sliding_window.bufs->dst_ebuf); free_dst_rkeys: - ucc_free(task->allreduce_sliding_window.dst_rkeys); + ucc_free(task->allreduce_sliding_window.bufs->dst_rkeys); free_rbufs: - ucc_free(task->allreduce_sliding_window.rbufs); + ucc_free(task->allreduce_sliding_window.bufs->rbufs); free_allgather_data: ucc_free(allgather_data); -free_pipe: - ucc_tl_ucp_allreduce_sliding_window_free_pipe(&task->super); +free_bufs: + ucc_free(task->allreduce_sliding_window.bufs); err: 
tl_error(UCC_TL_TEAM_LIB(tl_team), "error while allocating task"); return UCC_ERR_NO_RESOURCE; @@ -219,9 +220,9 @@ ucc_status_t ucc_tl_ucp_allreduce_sliding_window_allgather_info_finalize( return UCC_ERR_NO_RESOURCE; } - sw_task->allreduce_sliding_window.rbufs[i] = + sw_task->allreduce_sliding_window.bufs->rbufs[i] = all_host_allgather[i].dst_buf; - sw_task->allreduce_sliding_window.dst_rkeys[i] = dst_unpacked; + sw_task->allreduce_sliding_window.bufs->dst_rkeys[i] = dst_unpacked; if (!inplace) { ucs_status = ucp_ep_rkey_unpack( @@ -231,14 +232,14 @@ ucc_status_t ucc_tl_ucp_allreduce_sliding_window_allgather_info_finalize( return UCC_ERR_NO_RESOURCE; } - sw_task->allreduce_sliding_window.sbufs[i] = + sw_task->allreduce_sliding_window.bufs->sbufs[i] = all_host_allgather[i].src_buf; - sw_task->allreduce_sliding_window.src_rkeys[i] = src_unpacked; + sw_task->allreduce_sliding_window.bufs->src_rkeys[i] = src_unpacked; } else { - sw_task->allreduce_sliding_window.sbufs = - sw_task->allreduce_sliding_window.rbufs; - sw_task->allreduce_sliding_window.src_rkeys = - sw_task->allreduce_sliding_window.dst_rkeys; + sw_task->allreduce_sliding_window.bufs->sbufs = + sw_task->allreduce_sliding_window.bufs->rbufs; + sw_task->allreduce_sliding_window.bufs->src_rkeys = + sw_task->allreduce_sliding_window.bufs->dst_rkeys; } } @@ -255,23 +256,25 @@ ucc_tl_ucp_allreduce_sliding_window_free_task(ucc_coll_task_t *coll_task) ucc_tl_ucp_context_t *tl_ctx = UCC_TL_UCP_TEAM_CTX(tl_team); if (!inplace) { - ucc_free(task->allreduce_sliding_window.sbufs); + ucc_free(task->allreduce_sliding_window.bufs->sbufs); } - ucc_free(task->allreduce_sliding_window.rbufs); + ucc_free(task->allreduce_sliding_window.bufs->rbufs); ucc_free(task->allreduce_sliding_window.allgather_data); if (!inplace) { ucp_mem_unmap(tl_ctx->worker.ucp_context, - task->allreduce_sliding_window.src_ebuf->memh); - ucc_free(task->allreduce_sliding_window.src_ebuf); - ucc_free(task->allreduce_sliding_window.src_rkeys); + task->allreduce_sliding_window.bufs->src_ebuf->memh); + ucc_free(task->allreduce_sliding_window.bufs->src_ebuf); + ucc_free(task->allreduce_sliding_window.bufs->src_rkeys); } ucp_mem_unmap(tl_ctx->worker.ucp_context, - task->allreduce_sliding_window.dst_ebuf->memh); - ucc_free(task->allreduce_sliding_window.dst_ebuf); - ucc_free(task->allreduce_sliding_window.dst_rkeys); + task->allreduce_sliding_window.bufs->dst_ebuf->memh); + ucc_free(task->allreduce_sliding_window.bufs->dst_ebuf); + ucc_free(task->allreduce_sliding_window.bufs->dst_rkeys); + + ucc_free(task->allreduce_sliding_window.bufs); } void diff --git a/src/components/tl/ucp/tl_ucp_coll.h b/src/components/tl/ucp/tl_ucp_coll.h index f1fbdb2c76..57f79f8784 100644 --- a/src/components/tl/ucp/tl_ucp_coll.h +++ b/src/components/tl/ucp/tl_ucp_coll.h @@ -92,6 +92,8 @@ typedef struct ucc_tl_ucp_allreduce_sw_pipeline ucc_tl_ucp_allreduce_sw_pipeline; typedef struct ucc_tl_ucp_allreduce_sw_host_allgather ucc_tl_ucp_allreduce_sw_host_allgather; +typedef struct ucc_tl_ucp_dpu_offload_buf_info + ucc_tl_ucp_dpu_offload_buf_info_t; typedef struct ucc_tl_ucp_task { ucc_coll_task_t super; @@ -128,18 +130,12 @@ typedef struct ucc_tl_ucp_task { } allreduce_kn; struct { int reduce_in_progress; - ucp_rkey_h *src_rkeys; //unpacked - ucp_rkey_h *dst_rkeys; //unpacked - void **sbufs; - void **rbufs; ucc_tl_ucp_allreduce_sw_pipeline *pipe; - ucc_ee_executor_task_t *reduce_task; ucs_status_ptr_t *put_requests; - ucc_coll_task_t *allgather_task; ucc_tl_ucp_allreduce_sw_host_allgather *allgather_data; - 
ucc_coll_task_t *barrier_task; - struct ucc_tl_ucp_allreduce_sw_export_buf *src_ebuf; - struct ucc_tl_ucp_allreduce_sw_export_buf *dst_ebuf; + ucc_coll_task_t *allgather_task; + ucc_ee_executor_task_t *reduce_task; + ucc_tl_ucp_dpu_offload_buf_info_t *bufs; } allreduce_sliding_window; struct { int phase; diff --git a/src/components/tl/ucp/tl_ucp_dpu_offload.c b/src/components/tl/ucp/tl_ucp_dpu_offload.c new file mode 100644 index 0000000000..938cac9408 --- /dev/null +++ b/src/components/tl/ucp/tl_ucp_dpu_offload.c @@ -0,0 +1,42 @@ +/** + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#include "tl_ucp_dpu_offload.h" + +ucc_status_t ucc_tl_ucp_allreduce_sliding_window_register( + ucp_context_h ucp_context, ucc_tl_ucp_team_t *tl_team, + struct ucc_tl_ucp_allreduce_sw_export_buf *ebuf, void *packed_memh) +{ + ucs_status_t ucs_status; + ucp_mem_map_params_t params = {0}; + + ebuf->ucp_context = ucp_context; + + params.field_mask = UCP_MEM_MAP_PARAM_FIELD_EXPORTED_MEMH_BUFFER; + params.exported_memh_buffer = packed_memh; + + ucs_status = ucp_mem_map(ucp_context, ¶ms, &ebuf->memh); + if (UCS_OK != ucs_status) { + tl_error(UCC_TL_TEAM_LIB(tl_team), + "import using ucp_mem_map() returned error: %s\n", + ucs_status_string(ucs_status)); + return ucs_status_to_ucc_status(ucs_status); + } + + ucs_status = ucp_rkey_pack(ucp_context, ebuf->memh, &ebuf->packed_key, + &ebuf->packed_key_len); + if (UCS_OK != ucs_status) { + ucs_status_t unmap_status = ucp_mem_unmap(ucp_context, ebuf->memh); + tl_error(UCC_TL_TEAM_LIB(tl_team), + "ucp_rkey_pack() returned error: %s%s\n", + ucs_status_string(ucs_status), + unmap_status == UCS_OK ? "" : + ". While handling this error, unmapping the memh had an error\n"); + return ucs_status_to_ucc_status(ucs_status); + } + + return UCC_OK; +} \ No newline at end of file diff --git a/src/components/tl/ucp/tl_ucp_dpu_offload.h b/src/components/tl/ucp/tl_ucp_dpu_offload.h new file mode 100644 index 0000000000..940ed7bb40 --- /dev/null +++ b/src/components/tl/ucp/tl_ucp_dpu_offload.h @@ -0,0 +1,55 @@ +/** + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. 
+ */ + +#ifndef UCC_TL_UCP_DPU_OFFLOAD_H_ +#define UCC_TL_UCP_DPU_OFFLOAD_H_ + +#include "tl_ucp.h" +#include "schedule/ucc_schedule_pipelined.h" +#include "components/mc/base/ucc_mc_base.h" +#include "components/ec/ucc_ec.h" +#include "tl_ucp_tag.h" + + +#define ALLREDUCE_PACKED_KEY_MAX_LEN 1024 + +typedef struct ucc_tl_ucp_allreduce_sw_global_work_buf_info { + void *packed_src_memh; + void *packed_dst_memh; +} ucc_tl_ucp_allreduce_sw_global_work_buf_info_t; + +struct ucc_tl_ucp_allreduce_sw_export_buf { + ucp_context_h ucp_context; + ucp_mem_h memh; + void *packed_memh; + size_t packed_memh_len; + void *packed_key; + size_t packed_key_len; + uint64_t memh_id; +}; + +typedef struct ucc_tl_ucp_allreduce_sw_host_allgather { + void *src_buf; + void *dst_buf; + char packed_src_key[ALLREDUCE_PACKED_KEY_MAX_LEN]; + char packed_dst_key[ALLREDUCE_PACKED_KEY_MAX_LEN]; +} ucc_tl_ucp_allreduce_sw_host_allgather_t; + +typedef struct ucc_tl_ucp_dpu_offload_buf_info { + ucp_rkey_h *src_rkeys; //unpacked + ucp_rkey_h *dst_rkeys; //unpacked + void **sbufs; + void **rbufs; + struct ucc_tl_ucp_allreduce_sw_export_buf *src_ebuf; + struct ucc_tl_ucp_allreduce_sw_export_buf *dst_ebuf; +} ucc_tl_ucp_dpu_offload_buf_info_t; + +ucc_status_t ucc_tl_ucp_allreduce_sliding_window_register( + ucp_context_h ucp_context, ucc_tl_ucp_team_t *tl_team, + struct ucc_tl_ucp_allreduce_sw_export_buf *ebuf, void *packed_memh); + + +#endif
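-- 
Note on the schedule wiring (everything below the "--" delimiter is a
reviewer note, not part of the applied diff):
ucc_tl_ucp_allreduce_sliding_window_init() follows the usual UCC schedule
pattern: each task is added to the schedule, the first is subscribed to
UCC_EVENT_SCHEDULE_STARTED, and each subsequent task to its predecessor's
UCC_EVENT_COMPLETED, with ucc_task_start_handler posting the subscriber
when the event fires. A condensed sketch of the pattern, using only the
calls that appear in this patch (chain_two_tasks() and the task names are
illustrative, not part of the change):

/*
 * Sketch: chain task_a -> task_b under a schedule so that
 * ucc_schedule_start() triggers task_a and task_b starts only after
 * task_a completes. In this patch, task_a corresponds to the RDMA
 * task and task_b to the barrier task.
 */
static ucc_status_t chain_two_tasks(ucc_schedule_t  *schedule,
                                    ucc_coll_task_t *task_a,
                                    ucc_coll_task_t *task_b)
{
    ucc_status_t status;

    /* task_a runs first: triggered when the schedule itself starts */
    status = ucc_schedule_add_task(schedule, task_a);
    if (UCC_OK != status) {
        return status;
    }
    status = ucc_event_manager_subscribe(&schedule->super,
                                         UCC_EVENT_SCHEDULE_STARTED,
                                         task_a, ucc_task_start_handler);
    if (UCC_OK != status) {
        return status;
    }

    /* task_b runs second: triggered by task_a's completion */
    status = ucc_schedule_add_task(schedule, task_b);
    if (UCC_OK != status) {
        return status;
    }
    return ucc_event_manager_subscribe(task_a, UCC_EVENT_COMPLETED,
                                       task_b, ucc_task_start_handler);
}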
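The other half of the machinery is the key exchange that feeds the RDMA
task: each rank imports the application-exported memh and packs an rkey
with ucc_tl_ucp_allreduce_sliding_window_register() (now in
tl_ucp_dpu_offload.c), advertises the packed key through the allgather
task, and peers recover a usable ucp_rkey_h with ucp_ep_rkey_unpack()
(see allreduce_sliding_window_setup.c). A minimal sketch of one rank's
send side; pack_dst_key() and its parameter names are illustrative:

/*
 * Sketch: import the exported dst memh, pack its rkey, and stage the
 * buffer address plus the packed key in the allgather payload. Mirrors
 * ucc_tl_ucp_allreduce_sliding_window_start(); like that code, it
 * assumes packed_key_len <= ALLREDUCE_PACKED_KEY_MAX_LEN.
 */
static ucc_status_t pack_dst_key(ucc_tl_ucp_context_t *tl_ctx,
                                 ucc_tl_ucp_team_t    *team,
                                 void                 *packed_dst_memh,
                                 void                 *dst_buf,
                                 struct ucc_tl_ucp_allreduce_sw_export_buf *ebuf,
                                 ucc_tl_ucp_allreduce_sw_host_allgather_t  *msg)
{
    ucc_status_t status;

    /* ucp_mem_map() with EXPORTED_MEMH_BUFFER + ucp_rkey_pack() */
    status = ucc_tl_ucp_allreduce_sliding_window_register(
        tl_ctx->worker.ucp_context, team, ebuf, packed_dst_memh);
    if (UCC_OK != status) {
        return status;
    }

    /* peers will target dst_buf with ucp_get_nbx()/ucp_put_nbx() once
     * they unpack packed_dst_key via ucp_ep_rkey_unpack() */
    msg->dst_buf = dst_buf;
    memcpy(msg->packed_dst_key, ebuf->packed_key, ebuf->packed_key_len);
    return UCC_OK;
}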