[Major] Fuse bias+gemm and layernorm+quantization for more efficient ViT #254
base: main
Conversation
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
-dim3 block(std::min(hidden_size, 1024));
+dim3 block(std::min(hidden_size / 2, 1024)); // Prevent thread idling when the embedding size is greater than 1024 and not an integer multiple of it.
Can we fix this in a more elegant way to improve utilization?
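One option (a sketch, not this PR's code; `pick_block_size` is a hypothetical helper): compute the number of strided passes first, then split the row evenly across threads and round up to a whole warp, so the last pass has few or no idle threads for any hidden size:

```cpp
#include <algorithm>

// Sketch of a general block-size pick: every thread gets the same number of
// strided passes over the row, so utilization stays high for any hidden_size.
inline int pick_block_size(int hidden_size) {
    int iters = (hidden_size + 1023) / 1024;        // strided passes per thread
    int block = (hidden_size + iters - 1) / iters;  // split the row evenly
    block = ((block + 31) / 32) * 32;               // round up to a full warp
    return std::min(block, 1024);
}
// hidden_size = 1280 -> 2 passes, block = 640  (no idle threads)
// hidden_size = 1024 -> 1 pass,  block = 1024
```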
@@ -41,23 +40,24 @@ __inline__ __device__ Tf compute_layernorm(Tf val, float s_mean, float s_varianc
 * First pass (loop) computes the mean.
 * Second computes the variance via Var[x] = E[(x - E[x])²].
 * Third pass computes and writes normed_output
 *
-* with USE_DIFF_OF_SQUARES set to true (may be faster but less accurate):
+* For better speedup, we set USE_DIFF_OF_SQUARES to true (may be faster but less accurate):
Let's keep the original template for better flexibility.
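A minimal sketch of what keeping the template could look like (illustrative names, not this file's exact signatures): the flag stays a compile-time parameter and the host dispatches, so both variance formulations remain available:

```cpp
#include <cuda_fp16.h>
#include <cstdint>

// USE_DIFF_OF_SQUARES == true : Var[x] = E[x²] - E[x]² (one stats pass, less accurate)
// USE_DIFF_OF_SQUARES == false: Var[x] = E[(x - E[x])²] (extra pass, more accurate)
template <typename T, bool USE_DIFF_OF_SQUARES>
__global__ void layernorm_quant_kernel(const T* in, int8_t* out, int n) {
    // kernel body elided in this sketch
}

void launch(bool fast, const half* in, int8_t* out, int n, dim3 grid, dim3 block) {
    if (fast) layernorm_quant_kernel<half, true ><<<grid, block>>>(in, out, n);
    else      layernorm_quant_kernel<half, false><<<grid, block>>>(in, out, n);
}
```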
#pragma unroll
for (int i = 0; i < CTA_N; i++)
{
    Bias_shared[i] = __half2float(Bias[cta_offset_n + i]);
Is this the bottleneck?
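If profiling confirms it is, a common alternative (a sketch under the assumption that all threads of the block can cooperate at this point) is a strided, coalesced copy instead of each thread walking all CTA_N elements:

```cpp
#include <cuda_fp16.h>

// Sketch (not the PR's kernel): threads share the copy, so global loads
// coalesce and the loop shrinks from CTA_N to CTA_N / blockDim.x per thread.
template <int CTA_N>
__global__ void load_bias_example(const half* Bias, int cta_offset_n) {
    __shared__ float Bias_shared[CTA_N];
    for (int i = threadIdx.x; i < CTA_N; i += blockDim.x) {
        Bias_shared[i] = __half2float(Bias[cta_offset_n + i]);
    }
    __syncthreads();  // publish the tile to the whole block
    // ... consume Bias_shared ...
}
```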
@@ -592,10 +606,10 @@ void w8a8_gemm_fuse_bias_forward_cuda(torch::Tensor _in_feats,
 constexpr int CTA_M = 128;
 constexpr int CTA_N = 128;
 constexpr int CTA_K = 64;
-constexpr int WARP_M = 128;
+constexpr int WARP_M = 64;
[Issue] Why couldn't we have 128 here?
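For context, the warp tile partitions the CTA tile, so this choice sets the warp count per CTA. Illustrative arithmetic only (`warps_per_cta` is a hypothetical helper, and WARP_N is not shown in this hunk; it is assumed to be 64 below):

```cpp
// Hypothetical helper: warps per CTA for a given tiling.
template <int CTA_M, int CTA_N, int WARP_M, int WARP_N>
constexpr int warps_per_cta() {
    static_assert(CTA_M % WARP_M == 0 && CTA_N % WARP_N == 0,
                  "warp tiles must evenly partition the CTA tile");
    return (CTA_M / WARP_M) * (CTA_N / WARP_N);
}
// With the assumed WARP_N = 64:
//   warps_per_cta<128, 128, 128, 64>() == 2 warps ( 64 threads/CTA)
//   warps_per_cta<128, 128,  64, 64>() == 4 warps (128 threads/CTA)
```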
@@ -604,7 +618,7 @@ void w8a8_gemm_fuse_bias_forward_cuda(torch::Tensor _in_feats,
 constexpr int CTA_N = 64;
 constexpr int CTA_K = 64;
 constexpr int WARP_M = 32;
-constexpr int WARP_N = 32;
+constexpr int WARP_N = 16;
[Issue] Why couldn't we have 16 here?
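Same arithmetic as the `warps_per_cta` sketch above, with CTA_M assumed to be 64 (it is not shown in this hunk): halving WARP_N doubles the warps per CTA along N and shrinks each warp's output fragment.

```cpp
// With the assumed CTA_M = 64:
//   warps_per_cta<64, 64, 32, 32>() == 4 warps (128 threads/CTA)
//   warps_per_cta<64, 64, 32, 16>() == 8 warps (256 threads/CTA)
```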
@ys-2020 I have fused bias+gemm and layernorm+quantization. These optimizations achieve 1.5-1.6x and 1.5-2.0x kernel speedups respectively on an RTX 4090, which together lead to about a 1.15x end-to-end speedup for ViT. I have also made some optimizations for Orin, but there is still room for improvement; I will complete that later.
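For readers skimming the thread, a minimal sketch of the layernorm+quantization fusion idea (illustrative names and a per-tensor `inv_scale`; the actual kernel uses warp-shuffle reductions and multiple elements per thread rather than this naive form): normalize a token and emit int8 in one kernel, so the fp16 normed tensor never round-trips through global memory.

```cpp
#include <cuda_fp16.h>
#include <cstdint>

// One block per token; assumes blockDim.x == hidden_size <= 1024 and
// dynamic shared memory of hidden_size * sizeof(float) at launch.
__global__ void layernorm_quant_sketch(const half* __restrict__ in,
                                       int8_t* __restrict__ out,
                                       const half* __restrict__ gamma,
                                       const half* __restrict__ beta,
                                       float inv_scale, int hidden_size) {
    extern __shared__ float buf[];               // one row of activations
    const int row = blockIdx.x, i = threadIdx.x;
    float v = __half2float(in[row * hidden_size + i]);
    buf[i] = v;
    __syncthreads();

    // Naive single-thread reduction for brevity; a real kernel would use
    // warp shuffles / block reductions for the mean and variance.
    __shared__ float s_mean, s_rstd;
    if (i == 0) {
        float mean = 0.f, var = 0.f;
        for (int k = 0; k < hidden_size; ++k) mean += buf[k];
        mean /= hidden_size;
        for (int k = 0; k < hidden_size; ++k) {
            float d = buf[k] - mean;
            var += d * d;
        }
        s_mean = mean;
        s_rstd = rsqrtf(var / hidden_size + 1e-5f);
    }
    __syncthreads();

    // Normalize, apply affine parameters, and quantize in-register.
    float normed = (v - s_mean) * s_rstd * __half2float(gamma[i])
                 + __half2float(beta[i]);
    int q = __float2int_rn(normed * inv_scale);
    out[row * hidden_size + i] = static_cast<int8_t>(max(-128, min(127, q)));
}
```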