Skip to content

Commit

Permalink
EC/CPU: Fix int8 reductions not getting vectorized
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarka committed Feb 7, 2024
1 parent c13d26c commit 0796aff
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 15 deletions.
8 changes: 6 additions & 2 deletions src/components/ec/cpu/ec_cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,11 @@ ucc_status_t ucc_cpu_executor_task_post(ucc_ee_executor_t *executor,
switch (task_args->task_type) {
case UCC_EE_EXECUTOR_TASK_REDUCE:
status = ucc_ec_cpu_reduce((ucc_eee_task_reduce_t *)&task_args->reduce,
task_args->flags);
(task_args->flags &
UCC_EEE_TASK_FLAG_REDUCE_SRCS_EXT) ?
task_args->reduce.srcs_ext :
task_args->reduce.srcs,
task_args->flags);
if (ucc_unlikely(UCC_OK != status)) {
goto free_task;
}
Expand Down Expand Up @@ -147,7 +151,7 @@ ucc_status_t ucc_cpu_executor_task_post(ucc_ee_executor_t *executor,
tr.dst = trs->dst;
tr.alpha = trs->alpha;

status = ucc_ec_cpu_reduce(&tr, flags);
status = ucc_ec_cpu_reduce(&tr, srcs, flags);
if (ucc_unlikely(UCC_OK != status)) {
goto free_task;
}
Expand Down
2 changes: 1 addition & 1 deletion src/components/ec/cpu/ec_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@ typedef struct ucc_ec_cpu {

extern ucc_ec_cpu_t ucc_ec_cpu;

ucc_status_t ucc_ec_cpu_reduce(ucc_eee_task_reduce_t *task, uint16_t flags);
ucc_status_t ucc_ec_cpu_reduce(ucc_eee_task_reduce_t *task, void * const * restrict srcs, uint16_t flags);
#endif
23 changes: 11 additions & 12 deletions src/components/ec/cpu/ec_cpu_reduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,48 +12,49 @@
do { \
size_t _i, _j; \
type _tmp; \
size_t __count = _count; \
switch (_n_srcs) { \
case 2: \
for (_i = 0; _i < _count; _i++) { \
for (_i = 0; _i < __count; _i++) { \
d[_i] = OP##_2(s[0][_i], s[1][_i]); \
} \
break; \
case 3: \
for (_i = 0; _i < _count; _i++) { \
for (_i = 0; _i < __count; _i++) { \
d[_i] = OP##_3(s[0][_i], s[1][_i], s[2][_i]); \
} \
break; \
case 4: \
for (_i = 0; _i < _count; _i++) { \
for (_i = 0; _i < __count; _i++) { \
d[_i] = OP##_4(s[0][_i], s[1][_i], s[2][_i], s[3][_i]); \
} \
break; \
case 5: \
for (_i = 0; _i < _count; _i++) { \
for (_i = 0; _i < __count; _i++) { \
d[_i] = \
OP##_5(s[0][_i], s[1][_i], s[2][_i], s[3][_i], s[4][_i]); \
} \
break; \
case 6: \
for (_i = 0; _i < _count; _i++) { \
for (_i = 0; _i < __count; _i++) { \
d[_i] = OP##_6(s[0][_i], s[1][_i], s[2][_i], s[3][_i], \
s[4][_i], s[5][_i]); \
} \
break; \
case 7: \
for (_i = 0; _i < _count; _i++) { \
for (_i = 0; _i < __count; _i++) { \
d[_i] = OP##_7(s[0][_i], s[1][_i], s[2][_i], s[3][_i], \
s[4][_i], s[5][_i], s[6][_i]); \
} \
break; \
case 8: \
for (_i = 0; _i < _count; _i++) { \
for (_i = 0; _i < __count; _i++) { \
d[_i] = OP##_8(s[0][_i], s[1][_i], s[2][_i], s[3][_i], \
s[4][_i], s[5][_i], s[6][_i], s[7][_i]); \
} \
break; \
default: \
for (_i = 0; _i < _count; _i++) { \
for (_i = 0; _i < __count; _i++) { \
_tmp = OP##_8(s[0][_i], s[1][_i], s[2][_i], s[3][_i], \
s[4][_i], s[5][_i], s[6][_i], s[7][_i]); \
for (_j = 8; _j < _n_srcs; _j++) { \
Expand Down Expand Up @@ -223,11 +224,9 @@
} \
} while (0)

ucc_status_t ucc_ec_cpu_reduce(ucc_eee_task_reduce_t *task, uint16_t flags)
ucc_status_t ucc_ec_cpu_reduce(ucc_eee_task_reduce_t *task,
void * const * restrict srcs, uint16_t flags)
{
void **srcs = (flags & UCC_EEE_TASK_FLAG_REDUCE_SRCS_EXT) ? task->srcs_ext
: task->srcs;

switch (task->dt) {
case UCC_DT_INT8:
DO_DT_REDUCE_INT(int8_t, srcs, task->dst, task->op, task->count,
Expand Down

0 comments on commit 0796aff

Please sign in to comment.