
[WIP] Support TorchScript and graph rewrite #54

Open · wants to merge 12 commits into main

Conversation

@feihugis (Contributor)

This PR makes SelfAttention of the transformers BART model compatible with TorchScript and adds a graph rewriter/optimization for einsum.

@JiushengChen (Contributor)

Good to see this PR!

  1. I suppose PR #43 (Encoder-decoder Multihead attention cpu optimization) can be deprecated now. Please work with @NickNickGo to merge everything into this PR.
  2. Update the benchmarks in the scripts and READMEs.

@feihugis (Contributor, Author) commented Nov 11, 2020

Good to see this PR!

  1. I suppose PR #43 (Encoder-decoder Multihead attention cpu optimization) can be deprecated now. Please work with @NickNickGo to merge everything into this PR.

PR #43 (@NickNickGo) is for fairseq, and this PR currently only works for the transformers BART model, so there will be no conflicts between these two PRs.

[Jiusheng]: is it possible to cover both fairseq and transformers?

  2. Update the benchmarks in the scripts and READMEs.

There is a performance issue after the graph is rewritten. Once the issue is resolved, I will update the benchmark numbers and add docs.

@feihugis reopened this Nov 12, 2020
SelfAttention, _reorder_buffer)

from fastseq.logging import get_logger
from fastseq.utils.api_decorator import replace
from fastseq.optimizer.jit.graph_rewriter import optimize_graph
Contributor:

I see this is only for BART. With JIT, we should be able to optimize multiple models from the backend?

Contributor (Author):

Yes, the optimization can be applied to other models. The current limitation is that we need to check if other models are compatible with torch.jit.script.

Contributor:

This makes sense. Please include these in the PR. Looking forward to reviewing them.
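To make this concrete, here is a hedged usage sketch of how the rewriter could be applied to another model. It assumes optimize_graph (imported in this PR) rewrites a TorchScript graph in place; the only requirement on the model is that it scripts cleanly with torch.jit.script. TinyAttention is a hypothetical module used only for illustration.

```python
import torch
from torch import Tensor
# imported in this PR; its exact signature is an assumption here
from fastseq.optimizer.jit.graph_rewriter import optimize_graph

class TinyAttention(torch.nn.Module):
    """A hypothetical module whose einsum the rewriter is meant to pick up."""
    def forward(self, q: Tensor, k: Tensor) -> Tensor:
        return torch.einsum("bmhtd,bnhsd->bmhts", q, k)

scripted = torch.jit.script(TinyAttention())
optimize_graph(scripted.graph)  # assumed to rewrite aten::einsum nodes in place
```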

@graph_pattern
def einsum_rewrite_pattern_0(eqn: str, operands: List[Tensor]):
    # for cases like "bmhtd,bnhsd->bmhts"
    if (len(eqn) == 18 and eqn[0:3] == eqn[13:16] and eqn[0] == eqn[6] and
Contributor:

eqn[0:4] == eqn[13:17]?

Contributor:

Spaces are allowed in the equation; replace them first:

eqn = eqn.replace(' ', '')

Contributor (Author):

Yes, spaces are allowed here. One issue I'm working on is that adding the replace triggers a weird issue in IRParser.

        return r

    # for cases like "bmhts,bnhsd->bmhtd"
    if (len(eqn) == 18 and eqn[0:3] == eqn[13:16] and eqn[0] == eqn[6] and
Contributor:

same here

Contributor (Author):

Done.

def einsum_rewrite_pattern_0(eqn: str, operands: List[Tensor]):
    # for cases like "bmhtd,bnhsd->bmhts"
    if (len(eqn) == 18 and eqn[0:3] == eqn[13:16] and eqn[0] == eqn[6] and
            eqn[2] == eqn[8] and eqn[4] == eqn[10] and eqn[3] == eqn[16] and
Contributor:

eqn[3] == eqn[16] is unnecessary if eqn[0:4] == eqn[13:17] is used

Contributor (Author):

Done.
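To make the index checks above easier to follow, here is a small annotated sketch of how the positions in a length-18 equation line up (spaces already stripped):

```python
# For "bmhtd,bnhsd->bmhts": eqn[0:5] is the first operand, eqn[6:11] the
# second, and eqn[13:18] the output (eqn[5] == ',' and eqn[11:13] == '->').
eqn = "bmhtd,bnhsd->bmhts"
assert len(eqn) == 18
assert eqn[0:4] == eqn[13:17]   # output reuses the first operand's leading dims "bmht"
assert eqn[0] == eqn[6]         # both operands share the batch dim 'b'
assert eqn[2] == eqn[8]         # ... and the dim 'h'
assert eqn[4] == eqn[10]        # contracted dim 'd' is last in both operands
assert eqn[9] == eqn[17]        # 's' from the second operand is the last output dim
```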

    def test_einsum_rewriter(self):

        def run_einsum(t0: Tensor, t1: Tensor):
            r = torch.einsum("bmhtd,bnhsd->bmhts", t0, t1)
Contributor:

  1. Add some extra spaces in the equation.
  2. Use a different character set like i, j, k, etc.

Contributor (Author):

Item-2 is done.
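For reference, a minimal sketch of such a test using the alternative character set. optimize_graph is assumed here to rewrite the scripted function's graph in place, so the exact call may differ from the PR's test code.

```python
import torch
from torch import Tensor
from fastseq.optimizer.jit.graph_rewriter import optimize_graph  # signature assumed

def run_einsum(t0: Tensor, t1: Tensor) -> Tensor:
    # same structure as "bmhtd,bnhsd->bmhts", written with i/j/k/l/m instead
    return torch.einsum("ijklm,inksm->ijkls", t0, t1)

def test_einsum_rewriter_alt_chars():
    t0 = torch.randn(2, 4, 3, 5, 8)   # [i, j, k, l, m]
    t1 = torch.randn(2, 2, 3, 7, 8)   # [i, n, k, s, m]
    expected = run_einsum(t0, t1)

    scripted = torch.jit.script(run_einsum)
    optimize_graph(scripted.graph)    # assumed to rewrite the einsum node into bmm-based ops
    assert torch.allclose(expected, scripted(t0, t1), atol=1e-5)
```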

@NickNickGo (Contributor) left a comment:

Thanks @feihugis for this PR! Looks good in general.

  1. Looking forward to the speedup / profile comparison for the einsum op before/after.
  2. Can more cases/shapes be covered under the "einsum_rewrite_pattern_0" function?
  3. Could you briefly describe the changes made to make SelfAttention of the transformers BART model compatible with JIT? Maybe add a few comments in the code.

def einsum_rewrite_pattern_0(eqn: str, operands: List[Tensor]):
    # eqn = eqn.replace(' ', '')  # TODO: fix the issue: ValueError: stoll
    # for cases like "bmhtd,bnhsd->bmhts"
    if (len(eqn) == 18 and eqn[0:4] == eqn[13:17] and eqn[0] == eqn[6] and
Contributor:

Can we make this more general? The same pattern can be used for equations without a batch dimension.

Contributor (Author):

I prefer to leave it as it is. If we encounter such cases in the future, they can easily be added with a similar code block. Making it more general would end up looking like the implementation of the einsum kernel itself.

From the micro-benchmarking results, the runtime for large tensors is very similar with/without the optimization.


    # for cases like "bmhts,bnhsd->bmhtd"
    if (len(eqn) == 18 and eqn[0:4] == eqn[13:17] and eqn[0] == eqn[6] and
            eqn[2] == eqn[8] and eqn[4] == eqn[9] and eqn[10] == eqn[17]):
Contributor:

same.

        t0 = t0.reshape(b*h, m*t, s)
        t1 = t1.view(b*h, s, d)
        r = torch.bmm(t0, t1).view(b, h, m, t, d).permute(0, 2, 1, 3, 4)
        return r
Contributor:

Is the returned tensor contiguous? When comparing the speedup with einsum, please take this into account as well.

Contributor (Author):

No, the returned tensor is not contiguous, and the output of einsum is not contiguous either, so I think it is an apples-to-apples comparison.
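A quick, hedged way to check the contiguity claim on both sides of the comparison; the reshape/permute steps below are only an illustrative equivalent of the rewrite (with n fixed to 1), not the PR's exact code.

```python
import torch

b, m, h, t, s, d = 2, 4, 3, 5, 7, 8
q = torch.randn(b, m, h, t, s)
k = torch.randn(b, 1, h, s, d)   # n == 1 for simplicity

out_einsum = torch.einsum("bmhts,bnhsd->bmhtd", q, k)

# bmm-based equivalent of the same contraction
a = q.permute(0, 2, 1, 3, 4).reshape(b * h, m * t, s)
b_mat = k.squeeze(1).reshape(b * h, s, d)
out_bmm = torch.bmm(a, b_mat).view(b, h, m, t, d).permute(0, 2, 1, 3, 4)

print(out_einsum.is_contiguous(), out_bmm.is_contiguous())
print(torch.allclose(out_einsum, out_bmm, atol=1e-5))  # values should match
```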

Comment on lines 55 to 57
        time1 = timeit.Timer(
            functools.partial(script_run_einsum, eqn, t0, t1))
        s1 = time1.timeit(repeat_times)
Contributor:

Is CUDA synchronization taken care of?

Contributor (Author):

Good point. torch.cuda.synchronize() is added now.
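For reference, a minimal sketch (not the PR's exact benchmark code) of making CUDA timing reliable: synchronize before starting and after stopping the clock so queued asynchronous kernels are included in the measurement.

```python
import time
import torch

def timed(fn, *args, repeat_times: int = 100) -> float:
    """Wall-clock time of `repeat_times` calls, including pending CUDA work."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeat_times):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter() - start
```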

@feihugis force-pushed the dev_torchscript branch 2 times, most recently from c1955b7 to 5acea73 on November 17, 2020 06:40
@feihugis (Contributor, Author)

  3. Could you briefly describe the changes made to make SelfAttention of the transformers BART model compatible with JIT? Maybe add a few comments in the code.

The major change is to make the code work with the limited set of data types that TorchScript supports and to handle the behavioral differences between Python and TorchScript. For example, Python can update the values of a dictionary in place, but TorchScript cannot. The code logic is changed accordingly to handle these differences.
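As a purely illustrative (hypothetical, not the actual diff) example of this kind of change: TorchScript needs statically typed containers and is stricter about mutation patterns, so cache handling that freely updated a loosely typed dict in Python can be rewritten into explicitly typed, explicitly returned dictionaries.

```python
from typing import Dict
import torch
from torch import Tensor

@torch.jit.script
def update_cache(cache: Dict[str, Tensor], key: str, value: Tensor) -> Dict[str, Tensor]:
    # Build and return a new typed dict instead of relying on Python-style
    # in-place mutation of an untyped dict shared across modules.
    new_cache = torch.jit.annotate(Dict[str, Tensor], {})
    for k, v in cache.items():
        new_cache[k] = v
    new_cache[key] = value
    return new_cache
```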

@feihugis (Contributor, Author)

Based on the perf benchmarking below, the performance with/without the optimization is very similar.

Micro benchmark for optimized operation:

  • eqn='bmhtd,bnhsd->bmhts', shape0=[128, 4, 16, 5, 64], shape1=[128, 2, 16, 1024, 64]
    • einsum took: 3.4279239177703857;
    • optimized einsum torchscript took: 3.422758102416992;
    • optimized einsum python took: 3.422323703765869;
  • eqn='kmijd,knisd->kmijs', shape0=[128, 4, 16, 1, 64], shape1=[128, 2, 16, 1024, 64]
    • einsum took: 3.2339890003204346;
    • optimized einsum torchscript took: 3.231293201446533;
    • optimized einsum python took: 3.2313060760498047;
  • eqn='bmhts,bnhsd->bmhtd', shape0=[128, 4, 16, 5, 64], shape1=[128, 2, 16, 64, 1024]
    • einsum took: 5.048973798751831;
    • optimized einsum torchscript took: 5.0475754737854;
    • optimized einsum python took: 5.050021171569824;
  • eqn='impts,inpsw->imptw', shape0=[128, 4, 16, 3, 64], shape1=[128, 2, 16, 64, 7]
    • einsum took: 0.10066008567810059;
    • optimized einsum torchscript took: 0.08646607398986816;
    • optimized einsum python took: 0.08228182792663574;

E2E benchmark results:

  • with optimization

| Util | Model | Task | Split | BatchSize | Samples | Tokens | Bleu | Rouge | Loss | Perplexity | Runtime(seconds) | Throughput(samples/s) | Throughput(tokens/s) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.98\|14.97\|25.28 | NA | NA | 156 | 6.6 | NA |
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.94\|14.95\|25.26 | NA | NA | 92 | 11.1 | NA |
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.96\|14.97\|25.27 | NA | NA | 92 | 11.1 | NA |
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.97\|14.95\|25.27 | NA | NA | 92 | 11.1 | NA |
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.97\|14.92\|25.26 | NA | NA | 92 | 11.1 | NA |
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.98\|14.98\|25.25 | NA | NA | 91 | 11.3 | NA |
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.96\|14.98\|25.28 | NA | NA | 92 | 11.1 | NA |
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.97\|14.94\|25.29 | NA | NA | 92 | 11.1 | NA |
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.98\|15.01\|25.28 | NA | NA | 92 | 11.1 | NA |
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.98\|14.98\|25.26 | NA | NA | 92 | 11.1 | NA |
  • without optimization

| Util | Model | Task | Split | BatchSize | Samples | Tokens | Bleu | Rouge | Loss | Perplexity | Runtime(seconds) | Throughput(samples/s) | Throughput(tokens/s) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.97\|14.96\|25.27 | NA | NA | 132 | 7.8 | NA |
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.97\|14.95\|25.30 | NA | NA | 91 | 11.3 | NA |
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.98\|14.95\|25.25 | NA | NA | 92 | 11.1 | NA |
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.97\|14.95\|25.30 | NA | NA | 91 | 11.3 | NA |
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.96\|14.93\|25.27 | NA | NA | 93 | 11.0 | NA |
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.97\|14.99\|25.23 | NA | NA | 91 | 11.3 | NA |
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.97\|14.96\|25.25 | NA | NA | 92 | 11.1 | NA |
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.97\|14.94\|25.26 | NA | NA | 92 | 11.1 | NA |
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.97\|14.96\|25.26 | NA | NA | 92 | 11.1 | NA |
| transformers_v3.0.2+fastseq_v0.0.4 | facebook/bart-large-cnn | cnn_dm.1k/raw | val | 128 | 1024 | NA | NA | 34.98\|14.95\|25.26 | NA | NA | 91 | 11.3 | NA |

JiushengChen and others added 12 commits November 17, 2020 21:54
* Fix prophenet dict loading.

* Use logger.

* Fix import.
* Generate the XML log file for each unit tests

* run all fastseq unit tests

* Add Nikhil's changes on pipeline to publish XML

* Just use a small unit test to test pipeline

* Change the xml folder path

* Add more tests

* Add env var for xml log dir and test the failures

* Enable all fastseq unit tests

* Enable all tests

* Generate xml files for fairseq and transformers unit tests

* Fix an issue in pytest command

* Trigger the CI pipeline
… (#59)

* Update install_requires and enable fairseq to work with torch 1.6&1.7

* Better error message and address some warnings in torch1.7

* Raise the error if fairseq/transformers are installed but the optmizations can not be applied

* Move transformers/fairseq to extra_require

* Remove the out-of-dated build files for ngram cuda op

* Run fastseq units before transformers and fairseq