
[Major] Fuse bias+gemm and layernorm+quantization for more efficient ViT #254

Open · wants to merge 1 commit into main
Conversation

Louym (Contributor) commented on Jan 13, 2025:

@ys-2020 I have fused bias+gemm and layernorm+quantization. These optimizations achieve 1.5-1.6x and 1.5-2.0x kernel speedups respectively on RTX 4090, which translates into roughly a 1.15x end-to-end speedup for ViT. I have also made some of these optimizations for Orin, but there is still room for improvement there; I will finish that part later.
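For readers skimming the diff, the core of the layernorm+quantization fusion is that the normalized value is quantized in-register and written straight to int8, so the fp16 activations never take an extra round trip through global memory. Below is a minimal sketch of that idea, not the PR's actual kernel: the one-warp-per-token layout, the names, and the precomputed per-tensor scale inv_scale are assumptions made here for brevity.

// Illustrative only: fused LayerNorm + INT8 quantization, one warp per token.
// Assumptions (not from the PR): blockDim.x == 32 and a precomputed per-tensor
// quantization scale; the real kernel may use per-token scales and wider blocks.
#include <cuda_fp16.h>
#include <cstdint>

__global__ void layernorm_quant_sketch(const half* __restrict__ in,
                                       const half* __restrict__ gamma,
                                       const half* __restrict__ beta,
                                       int8_t* __restrict__ out,
                                       float inv_scale, int hidden_size, float eps) {
    const half* row = in + (size_t)blockIdx.x * hidden_size;  // one block per token
    float sum = 0.f, sumsq = 0.f;
    for (int i = threadIdx.x; i < hidden_size; i += 32) {
        float v = __half2float(row[i]);
        sum += v;
        sumsq += v * v;
    }
    // Butterfly warp reduction: every lane ends up with the full sums.
    for (int off = 16; off > 0; off >>= 1) {
        sum   += __shfl_xor_sync(0xffffffffu, sum, off);
        sumsq += __shfl_xor_sync(0xffffffffu, sumsq, off);
    }
    float mean = sum / hidden_size;
    float var  = sumsq / hidden_size - mean * mean;           // Var[x] = E[x^2] - E[x]^2
    float rstd = rsqrtf(var + eps);
    for (int i = threadIdx.x; i < hidden_size; i += 32) {
        float v = (__half2float(row[i]) - mean) * rstd;
        v = v * __half2float(gamma[i]) + __half2float(beta[i]);
        int q = __float2int_rn(v * inv_scale);                // quantize in-register, no extra kernel
        out[(size_t)blockIdx.x * hidden_size + i] = (int8_t)max(-128, min(127, q));
    }
}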

  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;
  dim3 grid(num_tokens);
- dim3 block(std::min(hidden_size, 1024));
+ dim3 block(std::min(hidden_size / 2, 1024));  // Prevent thread idling when the embedding size is greater than 1024 and not an integer multiple of it.
Contributor: Can we fix this in a more elegant way to improve utilization?
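One possible shape of a more elegant fix, written here as a hypothetical host-side helper rather than anything taken from the PR: derive the per-thread iteration count first, then size the block to a warp multiple that keeps every thread busy on every iteration.

// Hypothetical helper, not from the PR: choose a block size that avoids idle
// threads on the last strided iteration over a row of `hidden_size` elements.
#include <algorithm>

static int pick_block_size(int hidden_size, int max_threads = 1024) {
    int iters   = (hidden_size + max_threads - 1) / max_threads;  // elements per thread
    int threads = (hidden_size + iters - 1) / iters;              // threads needed for that many iterations
    threads     = ((threads + 31) / 32) * 32;                     // round up to a full warp
    return std::min(threads, max_threads);
}

For hidden_size = 1280 this lands on the same 640 threads as the hidden_size/2 workaround (instead of 1024 threads, 768 of which idle on the second strided iteration), but it also handles sizes where /2 still leaves an imbalance, e.g. 2304 becomes 3 iterations over 768 threads with no idle lanes.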

@@ -41,23 +40,24 @@ __inline__ __device__ Tf compute_layernorm(Tf val, float s_mean, float s_varianc
  * First pass (loop) computes the mean.
  * Second computes the variance via Var[x] = E[(x - E[x])²].
  * Third pass computes and writes normed_output
  *
- * with USE_DIFF_OF_SQUARES set to true (may be faster but less accurate):
+ * For better speedup, we set USE_DIFF_OF_SQUARES to true (may be faster but less accurate):
Contributor: Let's keep the original template for better flexibility.
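One way to read "keep the original template", sketched with placeholder names and signatures rather than the repo's: leave USE_DIFF_OF_SQUARES as a template parameter and let the host launcher choose the instantiation, so the ViT path can default to the fast variant without removing the accurate one.

// Illustrative sketch only; kernel/launcher names and parameters are placeholders.
#include <cuda_fp16.h>
#include <cstdint>
#include <algorithm>

template <bool USE_DIFF_OF_SQUARES>
__global__ void layernorm_quant_kernel(const half* in, int8_t* out,
                                       int hidden_size, float eps) {
    // ... same body as today; when USE_DIFF_OF_SQUARES is true the variance comes
    // from E[x^2] - E[x]^2 in a single pass, otherwise from a second pass over
    // (x - mean)^2 ...
}

void launch_layernorm_quant(const half* in, int8_t* out, int num_tokens,
                            int hidden_size, float eps, bool fast_variance,
                            cudaStream_t stream) {
    dim3 grid(num_tokens), block(std::min(hidden_size, 1024));
    if (fast_variance)  // the ViT path can pass true by default
        layernorm_quant_kernel<true><<<grid, block, 0, stream>>>(in, out, hidden_size, eps);
    else
        layernorm_quant_kernel<false><<<grid, block, 0, stream>>>(in, out, hidden_size, eps);
}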

#pragma unroll
for (int i = 0; i < CTA_N; i++)
{
    Bias_shared[i] = __half2float(Bias[cta_offset_n + i]);
}
Contributor: Is this the bottleneck?
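If the bias load does show up in profiles, one hedged alternative (the kernel's actual thread layout is not visible in this hunk, so this is only a guess at how it could look) is a cooperative strided load, so the CTA writes each of the CTA_N values once instead of every thread writing all of them:

// Hypothetical variant of the fragment above: spread the CTA_N bias elements
// across the thread block instead of looping over all of them in every thread.
for (int i = threadIdx.x; i < CTA_N; i += blockDim.x) {
    Bias_shared[i] = __half2float(Bias[cta_offset_n + i]);
}
__syncthreads();  // make the shared-memory bias visible to all threads before the epilogue reads it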

@@ -592,10 +606,10 @@ void w8a8_gemm_fuse_bias_forward_cuda(torch::Tensor _in_feats,
  constexpr int CTA_M = 128;
  constexpr int CTA_N = 128;
  constexpr int CTA_K = 64;
- constexpr int WARP_M = 128;
+ constexpr int WARP_M = 64;
Contributor: [Issue] Why couldn't we have 128 here?

@@ -604,7 +618,7 @@ void w8a8_gemm_fuse_bias_forward_cuda(torch::Tensor _in_feats,
  constexpr int CTA_N = 64;
  constexpr int CTA_K = 64;
  constexpr int WARP_M = 32;
- constexpr int WARP_N = 32;
+ constexpr int WARP_N = 16;
Contributor: [Issue] Why couldn't we have 16 here?
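For what it's worth, the legal WARP_M/WARP_N choices in tiled INT8 GEMMs are usually pinned down by divisibility and warp-count constraints like the ones below. These are generic checks written for illustration, not constraints read out of this repo, and CTA_M for this branch is not visible in the hunk, so it is assumed to be 64 here.

// Generic illustration, not this repo's actual constraints; CTA_M is assumed.
constexpr int CTA_M = 64, CTA_N = 64, CTA_K = 64;
constexpr int WARP_M = 32, WARP_N = 16;
static_assert(CTA_M % WARP_M == 0 && CTA_N % WARP_N == 0,
              "the warp tile must partition the CTA tile exactly");
constexpr int WARPS_PER_CTA = (CTA_M / WARP_M) * (CTA_N / WARP_N);
static_assert(WARPS_PER_CTA * 32 <= 1024, "CUDA block size limit");
// WARP_N = 16 -> 2 x 4 = 8 warps (256 threads) per CTA; WARP_N = 32 -> 2 x 2 = 4
// warps (128 threads), each holding a larger accumulator tile in registers.
// Occupancy versus register pressure is the usual trade-off behind such a change.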

ys-2020 changed the title from "[Minor] Fused some kernels" to "[Major] Fuse bias+gemm and layernorm+quantization for more efficient ViT" on Jan 14, 2025.