-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fuse batch normalization into convolution kernel #2629
base: main
Are you sure you want to change the base?
Conversation
Regarding terminology, what is preferred in StableHLO for convolution rhs: |
Is this to say - during training these values won't be We've discussed before that we'll need a way to adjust the knobs in terms of what patterns get applied, and that's a problem I plan to take on early next year. In the meantime, probably fine to have this pattern in this pass. If we decided it wasn't desirable on the default path, we can always make this it's own pass.
cc @ghpvnist regarding the terminology question, any preference from a spec perspective? |
I like |
Yes, I assume that’s why there are several operations like |
This fell off my radar a few weeks back - That all makes sense! Pattern LGTM if we can make the test file more targeted / shorter! |
Sorry for the lack of updates, been a bit swamped lately. Not sure how to make test shorter. Started by taking the kernel/weight from the first layer of the ResNet model (probably resnet18) in ONNX as my expected data. Then I took a random picture and ran it through the ONNX Runtime, compiled with debug flags, to dump the input and output data from that layer for the current test case. The goal is to see if the results from fused operators and the simplified batch normalization operations (according to the spec) match up. The problem is that the interpreter is running slower than I expected, so I cut down the input, expected output, and weights data (using stablehlo.slice and applying folding patterns to preserve the initial idea) to make it less CPU-intensive. But it’s still too slow I think I can trim it down even more. Also, I believe this requires a few tests to check which convolution configurations are currently supported. |
I need to figure out why bazel builds are so much slower than cmake..this test only took a few seconds on cmake. At a bare minimum I'll figure out a way to tag tests as large and not run the bazel CI for them. I didn't notice that this test was in I'm thinking about unit tests, i.e.
|
Made the following PR which lets |
Yes, i agree that it should be fine to use dummy data as we interested in transformations.
Cool. I`ll try to finish up this pull request.. |
da6cff5
to
b2551e2
Compare
This introduces a simplification that merges the batch normalization inference operation with convolution kernel (a.k.a. weight). The key idea is that the batch normalization parameters change during the training phase, but remain constant during inference. This means that the convolution kernel can be adjusted to incorporate the effects of batch normalization. This optimization is applied by default to the ResNet model in the ONNX framework.
It performs the following transformation:
into
using following calculations:
Similar optimization can be found in PyTorch:
https://github.com/pytorch/pytorch/blob/main/torch/nn/utils/fusion.py#L56