Skip to content

Commit

Permalink
Enhance rewrite pattern and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
feihugis committed Nov 14, 2020
1 parent 3ca2eda commit 0575353
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 38 deletions.
49 changes: 21 additions & 28 deletions fastseq/optimizer/jit/einsum_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,61 +10,54 @@

from fastseq.optimizer.jit.utils import graph_pattern, rewrite_graph

@graph_pattern
def einsum_pattern_0(t0: str, t1: List[Tensor]):
    """Match pattern: a bare ``torch.einsum(equation, operand_list)`` call.

    This function is not meant to be executed for its result; its TorchScript
    graph is used as the *search* pattern handed to ``rewrite_graph`` (see
    ``rewrite_einsum`` below).  ``graph_pattern`` presumably converts the
    function into a graph/pattern string — confirm in
    ``fastseq.optimizer.jit.utils``.

    Args:
        t0: the einsum equation string (e.g. ``"bmhtd,bnhsd->bmhts"``).
        t1: the list of operand tensors.

    Returns:
        The einsum result tensor.
    """
    r = torch.einsum(t0, t1)
    return r

@graph_pattern
def einsum_rewrite_pattern_0(eqn: str, operands: List[Tensor]):
    """Replacement pattern: rewrite two specific 5-D einsum contractions as
    ``torch.bmm``, which is substantially faster than the generic einsum path.

    Handled equations (letter names are positional, not literal — e.g.
    ``"kmijd,knisd->kmijs"`` also matches the first case):

    * ``"bmhtd,bnhsd->bmhts"``: contract over ``d`` and sum over ``n``.
    * ``"bmhts,bnhsd->bmhtd"``: contract over ``s`` and sum over ``n``.

    Any other equation falls through to plain ``torch.einsum``.

    Args:
        eqn: the einsum equation string.
        operands: exactly two 5-D tensors for the fast paths; arbitrary
            operands for the fallback.

    Returns:
        A tensor numerically equal (up to bmm/einsum float reassociation)
        to ``torch.einsum(eqn, operands)``.
    """
    # eqn = eqn.replace(' ', '') # TODO: fix the issue: ValueError: stoll
    # for cases like "bmhtd,bnhsd->bmhts"
    # eqn layout: [0:5]=lhs0, 5=',', [6:11]=lhs1, [11:13]='->', [13:18]=out.
    # eqn[0:4] == eqn[13:16+1] pins out's first four axes to lhs0's; the
    # remaining checks pin the shared batch/head axes and the output's last
    # axis to lhs1's 4th axis (the contraction runs over d and over n).
    if (len(eqn) == 18 and eqn[0:4] == eqn[13:17] and eqn[0] == eqn[6] and
            eqn[2] == eqn[8] and eqn[4] == eqn[10] and eqn[9] == eqn[17]):
        t0 = operands[0]
        t1 = operands[1]
        b, m, h, t, d = t0.shape
        s = t1.size(3)
        n = t1.size(1)
        t1 = t1.permute(0, 2, 3, 4, 1) # (b, h, s, d, n)
        # n does not appear in the output, so it is summed away up front;
        # skip the reduction when n == 1 (shape is already (b, h, s, d, 1)).
        if n > 1:
            t1 = t1.sum(dim=4, keepdim=True) # (b, h, s, d, 1)

        t0 = t0.permute(0, 2, 1, 3, 4) # (b, h, m, t, d)
        t1 = t1.permute(0, 1, 3, 4, 2) # (b, h, d, 1, s)
        t0 = t0.reshape(b*h, m*t, d)
        # view (not reshape): each target dim spans whole source dims with
        # compatible strides, so no copy is needed here.
        t1 = t1.view(b*h, d, s)
        r = torch.bmm(t0, t1).view(b, h, m, t, s).permute(0, 2, 1, 3, 4)
        return r

    # for cases like "bmhts,bnhsd->bmhtd"
    # Same layout as above, but now lhs0's last axis matches lhs1's 4th axis
    # (eqn[4] == eqn[9]) and the output's last axis is lhs1's 5th (eqn[10]
    # == eqn[17]); the contraction runs over s and over n.
    if (len(eqn) == 18 and eqn[0:4] == eqn[13:17] and eqn[0] == eqn[6] and
            eqn[2] == eqn[8] and eqn[4] == eqn[9] and eqn[10] == eqn[17]):
        t0 = operands[0]
        t1 = operands[1]
        b, m, h, t, s = t0.shape
        n = t1.size(1)
        d = t1.size(4)
        t1 = t1.permute(0, 2, 4, 3, 1) # (b, h, d, s, n)
        if n > 1:
            t1 = t1.sum(dim=4, keepdim=True) # (b, h, d, s, 1)
        t0 = t0.permute(0, 2, 1, 3, 4) # (b, h, m, t, s)
        t1 = t1.permute(0, 1, 3, 4, 2) # (b, h, s, 1, d)
        t0 = t0.reshape(b*h, m*t, s)
        t1 = t1.view(b*h, s, d)
        r = torch.bmm(t0, t1).view(b, h, m, t, d).permute(0, 2, 1, 3, 4)
        return r

    # Fallback: anything this rewriter does not recognize.
    return torch.einsum(eqn, operands)

# Pre-built pattern strings consumed by rewrite_einsum().  The functions are
# already decorated with @graph_pattern above, so they are called directly;
# wrapping them again with graph_pattern(...) would double-apply the decorator.
EINSUM_PATTERN_STR = einsum_pattern_0()
EINSUM_REWRITE_PATTERN_STR = einsum_rewrite_pattern_0()

def rewrite_einsum(input_graph: torch._C.Graph):
    """Rewrite, in place, einsum calls in ``input_graph`` that match
    ``EINSUM_PATTERN_STR`` into the bmm-based form given by
    ``EINSUM_REWRITE_PATTERN_STR`` (delegates to ``rewrite_graph``).
    """
    rewrite_graph(EINSUM_PATTERN_STR, EINSUM_REWRITE_PATTERN_STR, input_graph)
51 changes: 41 additions & 10 deletions tests/optimizer/jit/test_einsum_rewriter.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,65 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from typing import List
import functools
import logging
import timeit

from absl.testing import absltest
from absl.testing import absltest, parameterized
import torch
from torch import Tensor

from fastseq.logging import get_logger
from fastseq.optimizer.jit.einsum_rewriter import rewrite_einsum
from fastseq.utils.test_utils import TestCaseBase

logger = get_logger(__name__, logging.INFO)

class EinsumRewriterTest(TestCaseBase):
    """Checks that rewrite_einsum() replaces matching einsum graphs with bmm
    and that the rewritten TorchScript produces identical results."""

    @parameterized.parameters(
        {'eqn': "bmhtd,bnhsd->bmhts",
         'shape0': [128, 4, 16, 5, 64],
         'shape1': [128, 2, 16, 1024, 64]},
        {'eqn': "kmijd,knisd->kmijs",
         'shape0': [128, 4, 16, 1, 64],
         'shape1': [128, 2, 16, 1024, 64]},
        {'eqn': "bmhts,bnhsd->bmhtd",
         'shape0': [128, 4, 16, 3, 64],
         'shape1': [128, 2, 16, 64, 7]},
        {'eqn': "impts,inpsw->imptw",
         'shape0': [128, 4, 16, 3, 64],
         'shape1': [128, 2, 16, 64, 7]},
    )
    def test_einsum_rewriter(self, eqn, shape0, shape1):
        """Compare eager einsum vs. the rewritten scripted graph.

        Args:
            eqn: einsum equation under test (letter names vary to show the
                rewrite matches positionally, not literally).
            shape0: shape of the first operand.
            shape1: shape of the second operand.
        """

        def run_einsum(eqn: str, t0: Tensor, t1: Tensor):
            r = torch.einsum(eqn, t0, t1)
            return r

        # NOTE(review): requires a CUDA device; there is no CPU fallback here.
        t0 = torch.randn(shape0, dtype=torch.float32).cuda()
        t1 = torch.randn(shape1, dtype=torch.float32).cuda()
        repeat_times = 1000

        # Baseline: eager einsum result and timing.
        r0 = run_einsum(eqn, t0, t1)
        time0 = timeit.Timer(functools.partial(run_einsum, eqn, t0, t1))
        s0 = time0.timeit(repeat_times)

        # Script, rewrite the graph in place, and verify the bmm node landed.
        script_run_einsum = torch.jit.script(run_einsum)
        logger.debug(f"Original graph: \n{script_run_einsum.graph.str()}")
        rewrite_einsum(script_run_einsum.graph)
        logger.debug(f"Optimized graph: \n{script_run_einsum.graph.str()}")
        self.assertTrue('bmm' in script_run_einsum.graph.str())

        r1 = script_run_einsum(eqn, t0, t1)
        time1 = timeit.Timer(
            functools.partial(script_run_einsum, eqn, t0, t1))
        s1 = time1.timeit(repeat_times)

        # Exact equality is asserted, so the rewrite must be bitwise
        # identical to eager einsum on these inputs.
        self.assertTrue(torch.equal(r0, r1))
        logger.info(f"einsum took: {s0}; optimized einsum torchscript took: "
                    f"{s1};")


# Standard absl test entry point when the file is run as a script.
if __name__ == "__main__":
    absltest.main()

0 comments on commit 0575353

Please sign in to comment.