diff --git a/src/components/tl/cuda/bcast/bcast_linear.c b/src/components/tl/cuda/bcast/bcast_linear.c index f6712c2c46..743ddab245 100644 --- a/src/components/tl/cuda/bcast/bcast_linear.c +++ b/src/components/tl/cuda/bcast/bcast_linear.c @@ -62,6 +62,17 @@ ucc_status_t ucc_tl_cuda_bcast_linear_setup_test(ucc_tl_cuda_task_t *task) return ucc_tl_cuda_shm_barrier_test(UCC_TL_TEAM_RANK(team), task->bar); } +static inline size_t get_scratch_size(ucc_tl_cuda_team_t *team, + ucc_datatype_t dt) +{ + size_t dt_size = ucc_dt_size(dt); + ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team); + + ucc_assert((dt_size > 0) && (tsize > 0)); + + return UCC_TL_CUDA_TEAM_LIB(team)->cfg.scratch_size; +} + static inline ucc_status_t ecopy(void *dst, void *src, size_t size, ucc_ee_executor_t *exec, ucc_ee_executor_task_t **etask) @@ -90,6 +101,7 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task) ucc_tl_cuda_team_t *team = TASK_TEAM(task); ucc_rank_t trank = UCC_TL_TEAM_RANK(team); ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team); + ucc_datatype_t dt = task->bcast_linear.dt; ucc_status_t st; (void)team; (void)st; @@ -136,14 +148,23 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task) break; } + size_t scratch_size = get_scratch_size(team, dt); + size_t chunk_size = task->bcast_linear.step < task->bcast_linear.num_steps + ? ucc_min(scratch_size, task->bcast_linear.size) + : task->bcast_linear.size - + (task->bcast_linear.step - 1) * scratch_size; + size_t offset_buff = task->bcast_linear.step * scratch_size; + + // ucc_info("chunk_size: %ld", chunk_size); + if (trank == task->bcast_linear.root) { // fall-through between cases is intentional switch (task->bcast_linear.stage) { case STAGE_COPY: // copy from src buffer to scratch dbuf = TASK_SCRATCH(task, trank); - sbuf = task->bcast_linear.sbuf; - status = ecopy(dbuf, sbuf, task->bcast_linear.size, exec, + sbuf = PTR_OFFSET(task->bcast_linear.sbuf, offset_buff); + status = ecopy(dbuf, sbuf, chunk_size, exec, &task->bcast_linear.exec_task); task->bcast_linear.stage = STAGE_WAIT_COPY; break; @@ -156,7 +177,8 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task) task->bcast_linear.exec_task = NULL; // signal others ++task->bcast_linear.step; - set_rank_step(task, task->bcast_linear.root, task->bcast_linear.step, 0); + set_rank_step(task, task->bcast_linear.root, + task->bcast_linear.step, 0); task->bcast_linear.stage = STAGE_WAIT_ALL; } } @@ -173,8 +195,11 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task) } task->bcast_linear.stage = STAGE_COPY; // ucc_info("all others ready for next step"); - // TODO: remove - task->bcast_linear.stage = STAGE_DONE; + if (task->bcast_linear.stage < task->bcast_linear.num_steps) { + task->bcast_linear.stage = STAGE_COPY; + } else { + task->bcast_linear.stage = STAGE_DONE; + } break; case STAGE_DONE: task->super.status = UCC_OK; @@ -196,11 +221,13 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task) } break; case STAGE_CLIENT_COPY: - dbuf = task->bcast_linear.sbuf; - sbuf = TASK_SCRATCH(task, - task->bcast_linear.root); // need to copy from root's scratch buffer - status = ecopy(dbuf, sbuf, task->bcast_linear.size, exec, - &task->bcast_linear.exec_task); + dbuf = PTR_OFFSET(task->bcast_linear.sbuf, offset_buff); + sbuf = TASK_SCRATCH( + task, + task->bcast_linear + .root); // need to copy from root's scratch buffer + status = ecopy(dbuf, sbuf, chunk_size, exec, + &task->bcast_linear.exec_task); task->bcast_linear.stage = STAGE_CLIENT_COPY_WAIT; break; case STAGE_CLIENT_COPY_WAIT: @@ -212,8 +239,14 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task) task->bcast_linear.exec_task = NULL; ++task->bcast_linear.step; set_rank_step(task, trank, task->bcast_linear.step, 0); - task->bcast_linear.stage = - STAGE_DONE; // TODO: just for debug + // task->bcast_linear.stage = + // STAGE_DONE; // TODO: just for debug + if (task->bcast_linear.stage < + task->bcast_linear.num_steps) { + task->bcast_linear.stage = STAGE_COPY; + } else { + task->bcast_linear.stage = STAGE_DONE; + } } } break; @@ -235,7 +268,6 @@ ucc_status_t ucc_tl_cuda_bcast_linear_start(ucc_coll_task_t *coll_task) ucc_datatype_t dt = task->bcast_linear.dt; (void)tsize; - (void)args; (void)dt; task->bcast_linear.stage = STAGE_SYNC; @@ -244,8 +276,12 @@ ucc_status_t ucc_tl_cuda_bcast_linear_start(ucc_coll_task_t *coll_task) args->src.info.count); task->bcast_linear.size = ucc_dt_size(dt) * args->src.info.count; + size_t scratch_size = get_scratch_size(team, dt); + task->bcast_linear.num_steps = + ucc_div_round_up(task->bcast_linear.size, scratch_size); - ucc_info("bcast buffer size: %ld", task->bcast_linear.size); + ucc_info("bcast buffer size: %ld, num_steps: %d", task->bcast_linear.size, + task->bcast_linear.num_steps); task->bcast_linear.sbuf = args->src.info.buffer; task->bcast_linear.step = 0; @@ -275,7 +311,7 @@ ucc_status_t ucc_tl_cuda_bcast_linear_init(ucc_base_coll_args_t *coll_args, } task->bcast_linear.root = coll_args->args.root; - task->bcast_linear.dt = coll_args->args.src.info.datatype; + task->bcast_linear.dt = coll_args->args.src.info.datatype; ucc_info("bcast init with dt: %s", ucc_datatype_str(task->bcast_linear.dt)); task->bcast_linear.sbuf = coll_args->args.src.info.buffer; diff --git a/src/components/tl/cuda/tl_cuda.h b/src/components/tl/cuda/tl_cuda.h index 38608973ac..096bab1a0e 100644 --- a/src/components/tl/cuda/tl_cuda.h +++ b/src/components/tl/cuda/tl_cuda.h @@ -233,6 +233,7 @@ struct ucc_tl_cuda_task { ucc_datatype_t dt; ucc_rank_t root; size_t size; + int num_steps; ucc_ee_executor_task_t *exec_task; } bcast_linear; struct {