Skip to content

Commit

Permalink
Add inner loop unrolling for f32 GEMM on aarch64
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707440292
  • Loading branch information
alankelly authored and xnnpack-bot committed Dec 18, 2024
1 parent 9c682e5 commit 2428c72
Show file tree
Hide file tree
Showing 117 changed files with 9,382 additions and 2,100 deletions.
379 changes: 349 additions & 30 deletions bench/f32-gemm-minmax.cc

Large diffs are not rendered by default.

39 changes: 34 additions & 5 deletions cmake/gen/aarch64_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -123,38 +123,67 @@ SET(NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS
src/f32-dwconv/f32-dwconv-9p4c-minmax-asm-aarch64-neonfma.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neon-ld128-acc2-prfm.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neon-ld128-acc2.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld32-2.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-2.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc2-prfm.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc2.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc4-prfm.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc4.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-prfm.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld128-2.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld128-acc2-prfm.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld128-acc4-prfm.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld128-prfm.S
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld128.S
src/f32-gemm/gen/f32-gemm-1x12-minmax-asm-aarch64-neonfma-cortex-a53.S
src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld128.S
src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld32-2.S
src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld64-2.S
src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld128-2.S
src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld128.S
src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld32-2.S
src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld64-2.S
src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld128-2.S
src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld128.S
src/f32-gemm/gen/f32-gemm-4x1-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-4x1-minmax-asm-aarch64-neonfma-ld128.S
src/f32-gemm/gen/f32-gemm-4x2-minmax-asm-aarch64-neonfma-cortex-a75-prfm.S
src/f32-gemm/gen/f32-gemm-4x2-minmax-asm-aarch64-neonfma-cortex-a75.S
src/f32-gemm/gen/f32-gemm-4x2-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld32-2.S
src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld64-2.S
src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld128-2.S
src/f32-gemm/gen/f32-gemm-4x12-minmax-asm-aarch64-neonfma-cortex-a53.S
src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld128.S
src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-cortex-a75-prfm.S
src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-cortex-a75.S
src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld32-2.S
src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld64-2.S
src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld128-2.S
src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld32.S
src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld128.S
src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-aarch64-neonfma-cortex-a75.S
src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-aarch64-neonfma-ld32-2.S
src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-aarch64-neonfma-ld64-2.S
src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-aarch64-neonfma-ld64.S
src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-aarch64-neonfma-ld128-2.S
src/f32-gemm/gen/f32-gemm-7x8-minmax-asm-aarch64-neonfma-ld32-2.S
src/f32-gemm/gen/f32-gemm-7x8-minmax-asm-aarch64-neonfma-ld64-2.S
src/f32-gemm/gen/f32-gemm-7x8-minmax-asm-aarch64-neonfma-ld128-2.S
src/f32-gemm/gen/f32-gemm-8x8-minmax-asm-aarch64-neonfma-ld32-2.S
src/f32-gemm/gen/f32-gemm-8x8-minmax-asm-aarch64-neonfma-ld64-2.S
src/f32-gemm/gen/f32-gemm-8x8-minmax-asm-aarch64-neonfma-ld128-2.S
src/f32-gemm/gen/f32-gemm-goi-1x8-minmax-asm-aarch64-neonfma-ld128-prfm.S
src/f32-gemm/gen/f32-gemm-goi-1x8-minmax-asm-aarch64-neonfma-ld128.S
src/f32-gemm/gen/f32-gemm-goi-4x8-minmax-asm-aarch64-neonfma-ld128.S
Expand Down
131 changes: 87 additions & 44 deletions gemm_compiler/aarch64_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@

class Aarch64(base_architecture.BaseArchitecture):

def __init__(self):
self.decrement = 4
self.unroll_factor = 1

def astride_register(self):
return 'x4'

Expand All @@ -24,7 +28,7 @@ def cm_stride_register(self):
return 'x7'

def am_registers(self):
return [self.a_ptr_register()] + ['x9', 'x10', 'x11', 'x12', 'x21', 'x22']
return [self.a_ptr_register()] + ['x9', 'x10', 'x11', 'x12', 'x21', 'x22', 'x25']

def a_ptr_register(self):
return 'x3'
Expand All @@ -33,7 +37,7 @@ def c_ptr_register(self):
return 'x6'

def cm_registers(self):
return [self.c_ptr_register()] + ['x13', 'x14', 'x15', 'x19', 'x23', 'x24']
return [self.c_ptr_register()] + ['x13', 'x14', 'x15', 'x19', 'x23', 'x24', 'x26']

def w_ptr_register(self):
return 'x5'
Expand Down Expand Up @@ -102,7 +106,7 @@ def register_map_dword(self, reg):
return map[reg]

def function_name(self, M, N, isa):
return f'xnn_f32_gemm_minmax_ukernel_{M}x{N}__asm_aarch64_{isa}_lane\n'
return f'xnn_f32_gemm_minmax_ukernel_{M}x{N}__asm_aarch64_{isa}_ld32\n'

def quantization_params(self):
return ''
Expand All @@ -113,15 +117,16 @@ def header(self, M, N, prefix, isa):
HEADER += 'BEGIN_FUNCTION ' + self.function_name(M, N, isa)
HEADER += """
# Free up GP registers.
stp x19, x20, [sp, -48]
stp x21, x22, [sp, -32]
stp x23, x24, [sp, -16]
stp x19, x20, [sp, -64]
stp x21, x22, [sp, -48]
stp x23, x24, [sp, -32]
stp x25, x26, [sp, -16]
# Preserve callee saved q8-q15 registers.
stp q8, q9, [sp, -176]
stp q10, q11, [sp, -144]
stp q12, q13, [sp, -112]
stp q14, q15, [sp, -80]
stp d8, d9, [sp, -128]
stp d10, d11, [sp, -112]
stp d12, d13, [sp, -96]
stp d14, d15, [sp, -80]
# Load params.
ldr x13, [sp, 8]
Expand All @@ -137,26 +142,11 @@ def jump_to_label(self, label):
def read_a_registers(self, M):
return ''

def inner_loop(self, M, N):
def do_loop(self, M, N, i):
N_COUNT = N // self.n_step()
asm_string = '\ninner_loop:\n'
if 'before' in self.input_asm():
asm_string += self.input_asm()['before']
for mr in range(0, M):
for l in self.input_asm()['loop']:
asm_string += l.format(
AM_ptr=self.am_registers()[mr],
AM=self.a_registers(mr),
a_offset=self.k_register(),
)
if 'after' in self.input_asm():
asm_string += self.input_asm()['after']

# weights
if 'before' in self.weights_asm():
asm_string += self.weights_asm()['before']
asm_string = ''
for l in self.weights_asm()['loop_2']:
for nr in range(0, N_COUNT, 2):
for nr in range(0, N_COUNT - 1, 2):
asm_string += l.format(
W_ptr=self.w_ptr_register(),
W=self.w_registers()[nr],
Expand All @@ -165,7 +155,7 @@ def inner_loop(self, M, N):
w_step=self.register_bytes() * N_COUNT,
)
for l in self.weights_asm()['loop']:
if N_COUNT % 2 == 0:
if N_COUNT % 2 != 0:
asm_string += l.format(
W_ptr=self.w_ptr_register(),
W=self.w_registers()[nr],
Expand All @@ -184,9 +174,63 @@ def inner_loop(self, M, N):
W=self.w_registers()[nr],
A=self.a_registers(mr),
ACC=self.acc_registers()[M * nr + mr],
POS=i,
)
return asm_string

def inner_loop(self, M, N):
asm_string = ''
if self.unroll_factor > 1:
asm_string += '# Are there at least {DECREMENT} bytes?\n'.format(DECREMENT=self.unroll_factor * 4)
asm_string += 'cmp {k_register}, {DECREMENT}\n'.format(k_register=self.k_register(), DECREMENT=self.unroll_factor * 4)
asm_string += 'blt inner_loop_tail\n'
asm_string += 'sub {k_register}, {k_register}, {DECREMENT}\n'.format(k_register=self.k_register(), DECREMENT=self.unroll_factor * 4)

asm_string += '\ninner_loop:\n'
decrement = 4 * self.unroll_factor
if 'before' in self.input_asm():
asm_string += self.input_asm()['before']
for mr in range(0, M):
for l in self.input_asm()['loop']:
asm_string += l.format(
AM_ptr=self.am_registers()[mr],
AM=self.a_registers(mr),
a_offset=self.k_register(),
)
if 'after' in self.input_asm():
asm_string += self.input_asm()['after']

# weights
if 'before' in self.weights_asm():
asm_string += self.weights_asm()['before']
inner_loop_label = 'inner_loop'
if self.unroll_factor > 1:
for u in range(self.unroll_factor):
asm_string += self.do_loop(M, N, u)
# loop counter
asm_string += self.cmp_k_and_jump_if_less(label=inner_loop_label, decrement=decrement, cond='bhs')

asm_string += f'''
add x20, x20, {decrement}
cmp x20, 4
blt clamping
inner_loop_tail:\n'''
inner_loop_label = 'inner_loop_tail'

for mr in range(0, M):
for l in self.base_input_asm()['loop']:
asm_string += l.format(
AM_ptr=self.am_registers()[mr],
AM=self.a_registers(mr),
a_offset=self.k_register(),
)
asm_string += self.do_loop(M, N, 0)
# loop counter
asm_string += self.cmp_k_and_jump_if_less(label=inner_loop_label, decrement=4, cond='bne')
asm_string += '\n'

return asm_string

def outer_loop_prepare(self, M, N):
return ''

Expand Down Expand Up @@ -276,31 +320,30 @@ def clamp_inputs_and_outputs(
def increment_ptr(self, ptr, step):
return f'add {ptr}, {ptr}, {step}\n'

def zero_gp_register(self, reg):
return f'eor {reg}, {reg}, {reg}\n'
def initialize_k_register(self, reg):
kc_register = self.kc_register()
return f'mov {reg}, {kc_register}\n'

def cmp_k_and_jump_if_less(self, label):
def cmp_k_and_jump_if_less(self, label, decrement, cond):
kc_register = self.kc_register()
k_register = self.k_register()
return """add {k_register}, {k_register}, 4
cmp {kc_register}, {k_register}
bne {label}\n""".format(
label=label, k_register=k_register, kc_register=kc_register
)
return """subs {k_register}, {k_register}, {decrement}
{cond} {label}\n"""

def epilogue(self, M, N, isa):
restore_stack = """
return:
# Restore the callee saved GP registers.
ldp x19, x20, [sp, -48]
ldp x21, x22, [sp, -32]
ldp x23, x24, [sp, -16]
ldp x19, x20, [sp, -64]
ldp x21, x22, [sp, -48]
ldp x23, x24, [sp, -32]
ldp x25, x26, [sp, -16]
# Restore callee saved q8-q15 registers.
ldp q8, q9, [sp, -176]
ldp q10, q11, [sp, -144]
ldp q12, q13, [sp, -112]
ldp q14, q15, [sp, -80]
ldp d8, d9, [sp, -128]
ldp d10, d11, [sp, -112]
ldp d12, d13, [sp, -96]
ldp d14, d15, [sp, -80]
ret
END_FUNCTION {function_name}""".format(
M=M, N=N, function_name=isa.function_name(M, N, isa.isa())
Expand Down
4 changes: 2 additions & 2 deletions gemm_compiler/avx512f_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def store(
if pop_c:
asm_string += '\n' + '# Pop output pointers from the stack.\n'
c_reg_offset = 0
POP_C = 'mov {C_REG}, [rsp + {offset}]\n'
POP_C = 'mov {C_REG}, [rsp - {offset}]\n'
for mr in range(0, M):
sp_offset = 128 + (mr) * 16 + 8
asm_string += POP_C.format(C_REG=cm_registers[mr], offset=sp_offset)
Expand All @@ -208,7 +208,7 @@ def store(
)
if pop_c:
asm_string += '\n' + '# Write output pointers to the stack.\n'
POP_C = 'mov [rsp + {offset}], {C_REG}\n'
POP_C = 'mov [rsp - {offset}], {C_REG}\n'
for mr in range(0, M):
sp_offset = 128 + (mr) * 16 + 8
asm_string += POP_C.format(C_REG=cm_registers[mr], offset=sp_offset)
Expand Down
4 changes: 2 additions & 2 deletions gemm_compiler/base_architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ def increment_ptr(self, ptr, step):
raise NotImplementedError

@abstractmethod
def zero_gp_register(self, reg):
"""Zero the given general purpose register."""
def initialize_k_register(self, reg):
"""Initialized the given general purpose register for inner loop control."""
raise NotImplementedError

@abstractmethod
Expand Down
8 changes: 3 additions & 5 deletions gemm_compiler/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def generate_gemm_microkernel(

# the outer loop label
asm_string += '\nouter_loop:\n'
asm_string += '# Zero k counter.\n'
asm_string += isa.zero_gp_register(k_register)
asm_string += '# Initialize k counter.\n'
asm_string += isa.initialize_k_register(k_register)

# Read a registers from the stack if required
asm_string += isa.read_a_registers(M=M)
Expand All @@ -54,13 +54,11 @@ def generate_gemm_microkernel(
# inner loop
asm_string += isa.inner_loop(M, N)

# loop counter
asm_string += isa.cmp_k_and_jump_if_less(label='inner_loop')

asm_string += isa.dequantize(M=M, N=num_horizontal_registers, W=w_ptr_reg)

# min/max clamping
asm_string += '# Min/max clamping..\n'
asm_string += 'clamping:\n'
for nr in range(0, num_horizontal_registers):
for mr in range(0, M):
asm_string += isa.clamp_min(
Expand Down
22 changes: 18 additions & 4 deletions gemm_compiler/generate_f32_gemm_microkernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,28 @@ def generate_f32_gemm_microkernels():
),
)

for nr in range(8, 17, 8):
for unroll in {1, 2, 4}:
decrement = 32 * unroll
for mr in range(1, 6):
generate.generate_gemm_microkernel(
M=mr,
N=nr,
isa=neonfma_template.NeonFma(),
N=16,
isa=neonfma_template.NeonFmaUnolled(unroll),
output_file=os.path.join(
output_base,
f'f32-gemm-{mr}x16-minmax-asm-aarch64-neonfma-ld{decrement}.S',
),
)

for unroll in {1, 2, 4}:
decrement = 32 * unroll
for mr in range(1, 9):
generate.generate_gemm_microkernel(
M=mr,
N=8,
isa=neonfma_template.NeonFmaUnolled(unroll),
output_file=os.path.join(
output_base,
f'f32-gemm-{mr}x{nr}-minmax-asm-aarch64-neonfma-ld32.S',
f'f32-gemm-{mr}x8-minmax-asm-aarch64-neonfma-ld{decrement}-2.S',
),
)
20 changes: 0 additions & 20 deletions gemm_compiler/neondot_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,26 +74,6 @@ def quantization_params(self):
def quantization_params_register(self):
return 'x24'

def input_asm(self):
in_asm = {
'loop': [
'ldr d{AM}, [{AM_ptr}, {a_offset}]\n',
]
}
return in_asm

def weights_asm(self):
w_asm = {
'loop': [
'ldr q{W}, [{W_ptr}, {offset}]\n',
],
'loop_2': [
'ldp q{W}, q{W_1}, [{W_ptr}, {offset}]\n',
],
'after': 'add {W}, {W}, {w_step}\n',
}
return w_asm

def compute_asm(self):
c_asm = {
'loop': ['sdot v{ACC}.4s, v{W}.16b, v{A}.4b[0]\n'],
Expand Down
Loading

0 comments on commit 2428c72

Please sign in to comment.