[WIP] Support TorchScript and graph rewrite #54
base: main
Conversation
Good to see this PR!
The PR #43 (@NickNickGo) is for fairseq, and this PR currently only works for transformers-bart. There will be no conflicts between these two PRs. [Jiusheng]: Is it possible to cover both fairseq and transformers?
There is a performance issue after the graph is rewritten. Once the issue is resolved, I will update the benchmark numbers and add docs.
    SelfAttention, _reorder_buffer)

from fastseq.logging import get_logger
from fastseq.utils.api_decorator import replace
from fastseq.optimizer.jit.graph_rewriter import optimize_graph
I see this is only for BART; with JIT, should we be able to optimize multiple models from the backend?
Yes, the optimization can be applied to other models. The current limitation is that we need to check whether other models are compatible with torch.jit.script.
This makes sense. Please include these in the PR. Looking forward to reviewing them.
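(Not from the PR: a minimal sketch of the kind of torch.jit.script compatibility check discussed above; the helper name is hypothetical.)

```python
import torch

def is_scriptable(module: torch.nn.Module) -> bool:
    """Hypothetical helper: report whether a module compiles with torch.jit.script."""
    try:
        torch.jit.script(module)
        return True
    except Exception:
        return False
```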
@graph_pattern
def einsum_rewrite_pattern_0(eqn: str, operands: List[Tensor]):
    # for cases like "bmhtd,bnhsd->bmhts"
    if (len(eqn) == 18 and eqn[0:3] == eqn[13:16] and eqn[0] == eqn[6] and
eqn[0:4] == eqn[13:17]?
Spaces are allowed in the equation; replace them first:
eqn = eqn.replace(' ', '')
Yes, spaces are allowed here. One issue I'm working on is that adding the replace call triggers a strange error in the IRParser.
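For reference, a quick illustration (not in the PR) of the index layout these string checks rely on, using the example equation from the comment above:

```python
# Index layout of the 18-char equation "bmhtd,bnhsd->bmhts":
#   0..4  -> "bmhtd"  (first operand)
#   6..10 -> "bnhsd"  (second operand)
#  13..17 -> "bmhts"  (output)
eqn = "bmhtd,bnhsd->bmhts"
assert len(eqn) == 18
assert eqn[0:4] == eqn[13:17]  # output starts with the first operand's "bmht"
assert eqn[0] == eqn[6]        # shared batch dim "b"
assert eqn[2] == eqn[8]        # shared head dim "h"
assert eqn[4] == eqn[10]       # contracted dim "d"
assert eqn[9] == eqn[17]       # output's last dim is the second operand's "s"
```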
        return r

    # for cases like "bmhts,bnhsd->bmhtd"
    if (len(eqn) == 18 and eqn[0:3] == eqn[13:16] and eqn[0] == eqn[6] and
Same here.
Done.
def einsum_rewrite_pattern_0(eqn: str, operands: List[Tensor]):
    # for cases like "bmhtd,bnhsd->bmhts"
    if (len(eqn) == 18 and eqn[0:3] == eqn[13:16] and eqn[0] == eqn[6] and
            eqn[2] == eqn[8] and eqn[4] == eqn[10] and eqn[3] == eqn[16] and
eqn[3] == eqn[16] is unnecessary if eqn[0:4] == eqn[13:17] is used.
Done.
def test_einsum_rewriter(self):

    def run_einsum(t0: Tensor, t1: Tensor):
        r = torch.einsum("bmhtd,bnhsd->bmhts", t0, t1)
- add some extra spaces in the equation
- use a different character set, like i, j, k, etc.
Item-2 is done.
Force-pushed from cf38541 to 0575353
Thanks @feihugis for this PR! Looks good in general.
- Looking forward to the speedup/profile comparison for the einsum op before and after.
- Can more cases/shapes be covered by the einsum_rewrite_pattern_0 function?
- Could you briefly describe the changes needed to make SelfAttention of the transformers-bart model compatible with JIT? Maybe add a few comments in the code.
def einsum_rewrite_pattern_0(eqn: str, operands: List[Tensor]):
    # eqn = eqn.replace(' ', '')  # TODO: fix the issue: ValueError: stoll
    # for cases like "bmhtd,bnhsd->bmhts"
    if (len(eqn) == 18 and eqn[0:4] == eqn[13:17] and eqn[0] == eqn[6] and
Can we make this more general? The same pattern could be used for equations without a batch dimension.
I prefer to leave it as it is. If we hit those cases in the future, they can be added easily with a similar code block. Making it more general would move this toward a full einsum kernel implementation.
From the micro-benchmark results, the runtime for large tensors is very similar with and without the optimization.
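For context, a rough equivalence check (shapes arbitrary; assumes the second operand's n dim is 1, as the rewritten pattern does) showing that a reshape+bmm form matches torch.einsum for "bmhtd,bnhsd->bmhts":

```python
import torch

b, m, h, t, s, d = 2, 3, 4, 5, 6, 7
t0 = torch.randn(b, m, h, t, d)
t1 = torch.randn(b, 1, h, s, d)  # n == 1 assumed by the rewrite

ref = torch.einsum("bmhtd,bnhsd->bmhts", t0, t1)

a = t0.permute(0, 2, 1, 3, 4).reshape(b * h, m * t, d)  # (b*h, m*t, d)
c = t1.squeeze(1).reshape(b * h, s, d).transpose(1, 2)  # (b*h, d, s)
out = torch.bmm(a, c).view(b, h, m, t, s).permute(0, 2, 1, 3, 4)

assert torch.allclose(ref, out, atol=1e-5)
```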
    # for cases like "bmhts,bnhsd->bmhtd"
    if (len(eqn) == 18 and eqn[0:4] == eqn[13:17] and eqn[0] == eqn[6] and
            eqn[2] == eqn[8] and eqn[4] == eqn[9] and eqn[10] == eqn[17]):
Same here.
        t0 = t0.reshape(b*h, m*t, s)
        t1 = t1.view(b*h, s, d)
        r = torch.bmm(t0, t1).view(b, h, m, t, d).permute(0, 2, 1, 3, 4)
        return r
Is the returned tensor contiguous? When comparing the speedup with einsum, please take this into account as well.
No, the returned tensor is not contiguous, and the output of einsum is not contiguous either, so I think it is an apples-to-apples comparison.
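An easy way to verify this kind of claim (illustrative shapes, not the PR's test code) is to inspect .is_contiguous() on both outputs directly:

```python
import torch

b, m, h, t, s, d = 2, 3, 4, 5, 6, 8
r_einsum = torch.einsum("bmhts,bnhsd->bmhtd",
                        torch.randn(b, m, h, t, s), torch.randn(b, 1, h, s, d))
r_bmm = (torch.bmm(torch.randn(b * h, m * t, s), torch.randn(b * h, s, d))
         .view(b, h, m, t, d).permute(0, 2, 1, 3, 4))

# permute returns a strided view, so r_bmm is typically non-contiguous;
# print rather than assert, since einsum's output layout is an implementation detail.
print(r_einsum.is_contiguous(), r_bmm.is_contiguous())
```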
time1 = timeit.Timer(
    functools.partial(script_run_einsum, eqn, t0, t1))
s1 = time1.timeit(repeat_times)
Is CUDA synchronization taken care of?
Good point. torch.cuda.synchronize() is added now.
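For reference, a hedged sketch of the timing pattern with synchronization folded into the timed callable (names and shapes are illustrative, not the PR's test code):

```python
import functools
import timeit

import torch

def timed_einsum(eqn: str, t0: torch.Tensor, t1: torch.Tensor):
    r = torch.einsum(eqn, t0, t1)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for kernels; otherwise only launch cost is timed
    return r

t0, t1 = torch.randn(8, 64, 32), torch.randn(8, 32, 64)
timer = timeit.Timer(functools.partial(timed_einsum, "bij,bjk->bik", t0, t1))
print(timer.timeit(number=100))
```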
Force-pushed from c1955b7 to 5acea73
The major change is to make the code work with the limited set of data types that TorchScript supports and to handle the behavioral differences between Python and TorchScript. For example, Python can update dictionary values in place, but TorchScript cannot. The code logic is changed accordingly to handle these differences.
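As a hypothetical illustration of that kind of change (not the PR's actual code): TorchScript requires explicit container types, so eager-mode patterns that mutate loosely typed dicts are often rewritten to build a new, fully typed dict instead:

```python
from typing import Dict, Optional

import torch
from torch import Tensor

@torch.jit.script
def update_state(state: Dict[str, Optional[Tensor]],
                 key: str,
                 value: Tensor) -> Dict[str, Optional[Tensor]]:
    # Rebuild the dict so every value keeps the declared Optional[Tensor] type,
    # instead of relying on Python's duck-typed in-place updates.
    new_state: Dict[str, Optional[Tensor]] = {}
    for k, v in state.items():
        new_state[k] = v
    new_state[key] = value
    return new_state
```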
Based on the perf benchmarking below, the performance with and without the optimization is very similar.
Micro-benchmark for the optimized operation:
E2E benchmark results:
* Fix prophetnet dict loading.
* Use logger.
* Fix import.
* Generate the XML log file for each unit test
* Run all fastseq unit tests
* Add Nikhil's pipeline changes to publish XML
* Just use a small unit test to test the pipeline
* Change the XML folder path
* Add more tests
* Add an env var for the XML log dir and test the failures
* Enable all fastseq unit tests
* Enable all tests
* Generate XML files for fairseq and transformers unit tests
* Fix an issue in the pytest command
* Trigger the CI pipeline
… (#59)
* Update install_requires and enable fairseq to work with torch 1.6 & 1.7
* Better error messages and address some warnings in torch 1.7
* Raise an error if fairseq/transformers are installed but the optimizations cannot be applied
* Move transformers/fairseq to extra_require
* Remove the outdated build files for the ngram CUDA op
* Run fastseq unit tests before transformers and fairseq
This PR makes SelfAttention of the transformers-bart model compatible with TorchScript and adds the graph rewriter/optimization for einsum.