
Track broadcast axes in the shape_transform_descriptor #3610

Merged: 28 commits merged into develop from shape-transform-track-broadcast-axis on Dec 4, 2024

Conversation

pfultz2
Collaborator

@pfultz2 pfultz2 commented Nov 11, 2024

Although this prevents simplifying as much, it does help preserve the permutation of the broadcasted axes.

So if we have a tensor of `{2, 16, 10240}` that goes into a reduction along the last axis, it will output `{2, 16, 1}`, which may be broadcast back to `{2, 16, 10240}`, but there could be more shape transformations after the reduce and before a pointwise operator:

```
@1 = multibroadcast[out_lens={2, 16, 10240},out_dyn_dims={}](@0) -> int64_type, {2, 16, 10240}, {16, 1, 0}
@2 = reshape[dims={2, 160, 32, 32}](@1) -> int64_type, {2, 160, 32, 32}, {163840, 1024, 32, 1}
@3 = transpose[permutation={0, 2, 3, 1}](@2) -> int64_type, {2, 32, 32, 160}, {163840, 32, 1, 1024}
```
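
As a rough illustration (plain numpy, not MIGraphX code), the same chain can be reproduced by hand, taking the `{2, 16, 1}` reduce output as input; the variable names and dummy integer values below are hypothetical and only there for illustration:

```python
# Numpy sketch of the chain above, assuming the {2, 16, 1} reduce output as
# input (dummy values; names are made up for this example).
import numpy as np

x = np.arange(2 * 16, dtype=np.int64).reshape(2, 16, 1)  # reduce output {2, 16, 1}

y = np.broadcast_to(x, (2, 16, 10240))   # multibroadcast -> {2, 16, 10240}
y = y.reshape(2, 160, 32, 32)            # reshape        -> {2, 160, 32, 32}
original = y.transpose(0, 2, 3, 1)       # transpose      -> {2, 32, 32, 160}
print(original.shape)                    # (2, 32, 32, 160)
```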

On develop this would be simplified to:

```
@1 = unsqueeze[axes={1, 2, 5},steps={}](@0) -> int64_type, {2, 1, 1, 16, 1, 1}, {16, 16, 16, 1, 1, 1}
@2 = multibroadcast[out_lens={2, 1, 1, 16, 1, 10},out_dyn_dims={}](@1) -> int64_type, {2, 1, 1, 16, 1, 10}, {16, 16, 16, 1, 1, 0}
@3 = reshape[dims={2, 1, 1, 160}](@2) -> int64_type, {2, 1, 1, 160}, {160, 160, 160, 1}
@4 = multibroadcast[out_lens={2, 32, 32, 160},out_dyn_dims={}](@3) -> int64_type, {2, 32, 32, 160}, {160, 0, 0, 1}
```

Ideally, we would want to apply these transformations (without the broadcast) to the input before the reduction, but that isn't possible when it is simplified like above, because the shape_transform_descriptor doesn't track the permutation of the broadcasted axes. With this PR, it will simplify to:

```
@1 = unsqueeze[axes={3, 4},steps={}](@0) -> int64_type, {2, 16, 1, 1, 1}, {16, 1, 1, 1, 1}
@2 = transpose[permutation={0, 3, 4, 1, 2}](@1) -> int64_type, {2, 1, 1, 16, 1}, {16, 1, 1, 1, 1}
@3 = multibroadcast[out_lens={2, 1, 1, 16, 10},out_dyn_dims={}](@2) -> int64_type, {2, 1, 1, 16, 10}, {16, 1, 1, 1, 0}
@4 = reshape[dims={2, 1, 1, 160}](@3) -> int64_type, {2, 1, 1, 160}, {160, 160, 160, 1}
@5 = multibroadcast[out_lens={2, 32, 32, 160},out_dyn_dims={}](@4) -> int64_type, {2, 32, 32, 160}, {160, 0, 0, 1}
```
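
As a back-of-envelope check (again a numpy sketch with made-up names, not part of this PR), both simplifications compute the same tensor as the original chain; the difference is that the new form keeps the permutation explicit via the transpose:

```python
# Numpy sketch: the original chain, the develop simplification, and the new
# simplification all produce the same tensor for a {2, 16, 1} input.
import numpy as np

x = np.arange(2 * 16, dtype=np.int64).reshape(2, 16, 1)

# Original chain: multibroadcast -> reshape -> transpose
original = np.broadcast_to(x, (2, 16, 10240)).reshape(2, 160, 32, 32).transpose(0, 2, 3, 1)

# develop simplification: unsqueeze -> multibroadcast -> reshape -> multibroadcast
d = np.expand_dims(x, axis=(1, 2, 5))                    # {2, 1, 1, 16, 1, 1}
d = np.broadcast_to(d, (2, 1, 1, 16, 1, 10)).reshape(2, 1, 1, 160)
develop = np.broadcast_to(d, (2, 32, 32, 160))

# This PR: unsqueeze -> transpose -> multibroadcast -> reshape -> multibroadcast
n = np.expand_dims(x, axis=(3, 4))                       # {2, 16, 1, 1, 1}
n = n.transpose(0, 3, 4, 1, 2)                           # {2, 1, 1, 16, 1}
n = np.broadcast_to(n, (2, 1, 1, 16, 10)).reshape(2, 1, 1, 160)
new = np.broadcast_to(n, (2, 32, 32, 160))

print(np.array_equal(original, develop), np.array_equal(original, new))  # True True
```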

This has a transpose because the shape_transform_descriptor understands that the output will be in NHWC layout, which means we can make the input to the reduction NHWC as well. This PR doesn't enable such a rewrite; it only modifies the shape_transform_descriptor to track such layouts.

There are also some updates to the tests:

- Validate that a simplified transformation produces the same result
- Check that the simplification cannot be simplified further

@pfultz2 pfultz2 requested a review from causten as a code owner November 11, 2024 19:15
@TedThemistokleous TedThemistokleous added the enhancement New feature or request label Nov 27, 2024
@TedThemistokleous
Collaborator

Fix CI but otherwise looks good


codecov bot commented Nov 28, 2024

Codecov Report

Attention: Patch coverage is 96.03960% with 4 lines in your changes missing coverage. Please review.

Project coverage is 92.23%. Comparing base (241e24e) to head (64a50f7).
Report is 5 commits behind head on develop.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/shape_transform_descriptor.cpp | 96.03% | 4 Missing ⚠️ |
Additional details and impacted files
```
@@             Coverage Diff             @@
##           develop    #3610      +/-   ##
===========================================
+ Coverage    92.20%   92.23%   +0.02%     
===========================================
  Files          513      513              
  Lines        21658    21726      +68     
===========================================
+ Hits         19970    20038      +68     
  Misses        1688     1688              
```


@migraphx-bot
Collaborator

| Test | Batch | Rate new (64a50f) | Rate old (241e24) | Diff | Compare |
|---|---|---|---|---|---|
| torchvision-resnet50 | 64 | 3,255.03 | 3,253.39 | 0.05% | |
| torchvision-resnet50_fp16 | 64 | 6,989.80 | 6,971.83 | 0.26% | |
| torchvision-densenet121 | 32 | 2,431.22 | 2,432.57 | -0.06% | |
| torchvision-densenet121_fp16 | 32 | 4,092.43 | 4,102.36 | -0.24% | |
| torchvision-inceptionv3 | 32 | 1,628.97 | 1,627.63 | 0.08% | |
| torchvision-inceptionv3_fp16 | 32 | 2,750.44 | 2,741.25 | 0.34% | |
| cadene-inceptionv4 | 16 | 765.03 | 763.96 | 0.14% | |
| cadene-resnext64x4 | 16 | 810.16 | 809.72 | 0.05% | |
| slim-mobilenet | 64 | 7,462.81 | 7,381.84 | 1.10% | |
| slim-nasnetalarge | 64 | 208.47 | 208.40 | 0.03% | |
| slim-resnet50v2 | 64 | 3,435.87 | 3,439.64 | -0.11% | |
| bert-mrpc-onnx | 8 | 1,143.63 | 1,145.80 | -0.19% | |
| bert-mrpc-tf | 1 | 460.49 | 492.93 | -6.58% | 🔴 |
| pytorch-examples-wlang-gru | 1 | 527.91 | 426.82 | 23.69% | 🔆 |
| pytorch-examples-wlang-lstm | 1 | 401.55 | 392.81 | 2.23% | |
| torchvision-resnet50_1 | 1 | 787.71 | 777.08 | 1.37% | |
| cadene-dpn92_1 | 1 | 402.67 | 429.78 | -6.31% | 🔴 |
| cadene-resnext101_1 | 1 | 382.29 | 383.15 | -0.23% | |
| onnx-taau-downsample | 1 | 345.95 | 346.23 | -0.08% | |
| dlrm-criteoterabyte | 1 | 33.33 | 33.33 | -0.01% | |
| dlrm-criteoterabyte_fp16 | 1 | 52.76 | 52.75 | 0.02% | |
| agentmodel | 1 | 8,359.46 | 8,083.75 | 3.41% | 🔆 |
| unet_fp16 | 2 | 58.91 | 58.78 | 0.22% | |
| resnet50v1_fp16 | 1 | 1,020.82 | 943.59 | 8.18% | 🔆 |
| resnet50v1_int8 | 1 | 1,008.13 | 1,004.54 | 0.36% | |
| bert_base_cased_fp16 | 64 | 1,169.80 | 1,169.61 | 0.02% | |
| bert_large_uncased_fp16 | 32 | 363.53 | 363.60 | -0.02% | |
| bert_large_fp16 | 1 | 199.20 | 200.33 | -0.56% | |
| distilgpt2_fp16 | 16 | 2,202.65 | 2,196.55 | 0.28% | |
| yolov5s | 1 | 540.70 | 526.38 | 2.72% | |
| tinyllama | 1 | 43.63 | 43.64 | -0.02% | |
| vicuna-fastchat | 1 | 174.80 | 173.50 | 0.75% | |
| whisper-tiny-encoder | 1 | 417.72 | 417.83 | -0.03% | |
| whisper-tiny-decoder | 1 | 427.53 | 427.29 | 0.06% | |

This build is not recommended to merge 🔴

@migraphx-bot
Collaborator


- ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance
- ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance
- ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance
- ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance
- ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance
- ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance
- ✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance
- ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance
- ✅ agentmodel: PASSED: MIGraphX meets tolerance
- ✅ unet: PASSED: MIGraphX meets tolerance
- ✅ resnet50v1: PASSED: MIGraphX meets tolerance
- ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance
- 🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
- ✅ bert_large: PASSED: MIGraphX meets tolerance
- ✅ yolov5s: PASSED: MIGraphX meets tolerance
- ✅ tinyllama: PASSED: MIGraphX meets tolerance
- ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance
- ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance
- ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance
- ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance
@CharlieL7 CharlieL7 removed their request for review December 3, 2024 17:00
@pfultz2 pfultz2 merged commit 935b96b into develop Dec 4, 2024
43 of 45 checks passed
@pfultz2 pfultz2 deleted the shape-transform-track-broadcast-axis branch December 4, 2024 17:05
shivadbhavsar pushed a commit that referenced this pull request Dec 18, 2024
Labels: enhancement (New feature or request)

5 participants