From 82d310845a142493ea214055d99fee4fb30de330 Mon Sep 17 00:00:00 2001 From: Nick Sarkauskas Date: Wed, 7 Feb 2024 15:09:33 -0500 Subject: [PATCH] EC/CPU: Fix int8 reductions not getting vectorized (#918) Co-authored-by: Nick Sarkauskas --- src/components/ec/cpu/ec_cpu.c | 8 ++++++-- src/components/ec/cpu/ec_cpu.h | 2 +- src/components/ec/cpu/ec_cpu_reduce.c | 23 +++++++++++------------ 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/components/ec/cpu/ec_cpu.c b/src/components/ec/cpu/ec_cpu.c index 8d08d2365c..9ba29c05d5 100644 --- a/src/components/ec/cpu/ec_cpu.c +++ b/src/components/ec/cpu/ec_cpu.c @@ -114,7 +114,11 @@ ucc_status_t ucc_cpu_executor_task_post(ucc_ee_executor_t *executor, switch (task_args->task_type) { case UCC_EE_EXECUTOR_TASK_REDUCE: status = ucc_ec_cpu_reduce((ucc_eee_task_reduce_t *)&task_args->reduce, - task_args->flags); + (task_args->flags & + UCC_EEE_TASK_FLAG_REDUCE_SRCS_EXT) ? + task_args->reduce.srcs_ext : + task_args->reduce.srcs, + task_args->flags); if (ucc_unlikely(UCC_OK != status)) { goto free_task; } @@ -147,7 +151,7 @@ ucc_status_t ucc_cpu_executor_task_post(ucc_ee_executor_t *executor, tr.dst = trs->dst; tr.alpha = trs->alpha; - status = ucc_ec_cpu_reduce(&tr, flags); + status = ucc_ec_cpu_reduce(&tr, srcs, flags); if (ucc_unlikely(UCC_OK != status)) { goto free_task; } diff --git a/src/components/ec/cpu/ec_cpu.h b/src/components/ec/cpu/ec_cpu.h index 0c6b63c92c..bbbaf0b341 100644 --- a/src/components/ec/cpu/ec_cpu.h +++ b/src/components/ec/cpu/ec_cpu.h @@ -25,5 +25,5 @@ typedef struct ucc_ec_cpu { extern ucc_ec_cpu_t ucc_ec_cpu; -ucc_status_t ucc_ec_cpu_reduce(ucc_eee_task_reduce_t *task, uint16_t flags); +ucc_status_t ucc_ec_cpu_reduce(ucc_eee_task_reduce_t *task, void * const * restrict srcs, uint16_t flags); #endif diff --git a/src/components/ec/cpu/ec_cpu_reduce.c b/src/components/ec/cpu/ec_cpu_reduce.c index a805d67294..8c065c6f24 100644 --- a/src/components/ec/cpu/ec_cpu_reduce.c +++ b/src/components/ec/cpu/ec_cpu_reduce.c @@ -12,48 +12,49 @@ do { \ size_t _i, _j; \ type _tmp; \ + size_t __count = _count; \ switch (_n_srcs) { \ case 2: \ - for (_i = 0; _i < _count; _i++) { \ + for (_i = 0; _i < __count; _i++) { \ d[_i] = OP##_2(s[0][_i], s[1][_i]); \ } \ break; \ case 3: \ - for (_i = 0; _i < _count; _i++) { \ + for (_i = 0; _i < __count; _i++) { \ d[_i] = OP##_3(s[0][_i], s[1][_i], s[2][_i]); \ } \ break; \ case 4: \ - for (_i = 0; _i < _count; _i++) { \ + for (_i = 0; _i < __count; _i++) { \ d[_i] = OP##_4(s[0][_i], s[1][_i], s[2][_i], s[3][_i]); \ } \ break; \ case 5: \ - for (_i = 0; _i < _count; _i++) { \ + for (_i = 0; _i < __count; _i++) { \ d[_i] = \ OP##_5(s[0][_i], s[1][_i], s[2][_i], s[3][_i], s[4][_i]); \ } \ break; \ case 6: \ - for (_i = 0; _i < _count; _i++) { \ + for (_i = 0; _i < __count; _i++) { \ d[_i] = OP##_6(s[0][_i], s[1][_i], s[2][_i], s[3][_i], \ s[4][_i], s[5][_i]); \ } \ break; \ case 7: \ - for (_i = 0; _i < _count; _i++) { \ + for (_i = 0; _i < __count; _i++) { \ d[_i] = OP##_7(s[0][_i], s[1][_i], s[2][_i], s[3][_i], \ s[4][_i], s[5][_i], s[6][_i]); \ } \ break; \ case 8: \ - for (_i = 0; _i < _count; _i++) { \ + for (_i = 0; _i < __count; _i++) { \ d[_i] = OP##_8(s[0][_i], s[1][_i], s[2][_i], s[3][_i], \ s[4][_i], s[5][_i], s[6][_i], s[7][_i]); \ } \ break; \ default: \ - for (_i = 0; _i < _count; _i++) { \ + for (_i = 0; _i < __count; _i++) { \ _tmp = OP##_8(s[0][_i], s[1][_i], s[2][_i], s[3][_i], \ s[4][_i], s[5][_i], s[6][_i], s[7][_i]); \ for (_j = 8; _j < _n_srcs; _j++) { \ @@ -223,11 +224,9 @@ } \ } while (0) -ucc_status_t ucc_ec_cpu_reduce(ucc_eee_task_reduce_t *task, uint16_t flags) +ucc_status_t ucc_ec_cpu_reduce(ucc_eee_task_reduce_t *task, + void * const * restrict srcs, uint16_t flags) { - void **srcs = (flags & UCC_EEE_TASK_FLAG_REDUCE_SRCS_EXT) ? task->srcs_ext - : task->srcs; - switch (task->dt) { case UCC_DT_INT8: DO_DT_REDUCE_INT(int8_t, srcs, task->dst, task->op, task->count,