
Improve fusions with dequantizelinear #3551

Closed
causten opened this issue Oct 24, 2024 · 8 comments
Labels: bug, high priority


@causten
Collaborator

causten commented Oct 24, 2024

The dequantizelinear_kernel in the llama2 7B 16a4w model is not fusing as expected.

@causten causten added bug Something isn't working high priority A PR with high priority for review and merging. labels Oct 24, 2024
@turneram
Contributor

I see two patterns in llama2-int4 that are not getting fused.
unpack_int4->dequantize_linear->reshape->transpose->dot
and
unpack_int4->unsqueeze->multibroadcast->contiguous->reshape_lazy->dequantize_linear->reshape->transpose->dot
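For reference, both chains begin with the same elementwise math. Below is a minimal pure-Python sketch of what unpack_int4 followed by dequantize_linear computes. It assumes the ONNX int4 packing order (first value in the low nibble) and per-element scale and zero-point, and the function names only mirror the operators; they are not the MIGraphX API.

```python
def unpack_int4(packed):
    """Split each uint8 into two unsigned 4-bit values (assumed low nibble first)."""
    out = []
    for byte in packed:
        out.append(byte & 0x0F)         # first element: 4 least-significant bits
        out.append((byte >> 4) & 0x0F)  # second element: 4 most-significant bits
    return out


def dequantize_linear(q, scale, zero_point):
    """ONNX DequantizeLinear semantics, elementwise: y = (q - zero_point) * scale."""
    return [(x - zp) * s for x, s, zp in zip(q, scale, zero_point)]


q = unpack_int4([0x21, 0xF0])
print(q)                                         # [1, 2, 0, 15]
print(dequantize_linear(q, [0.5] * 4, [8] * 4))  # [-3.5, -3.0, -4.0, 3.5]
```

Because every step here is elementwise, the reshape and transpose that follow only change layout. That is why one would expect the whole chain to fold into the dot fusion rather than run as separate kernels.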

@pfultz2
Collaborator

pfultz2 commented Oct 24, 2024

Can you show the operators with their shapes, so we can understand how it's reshaping and broadcasting?

@lakhinderwalia
Contributor

A test case can show the matcher isn't being hit for unpack_int4:

bin/driver compile ../test/onnx/matmulnbits_mm2_test.onnx

module: "main"
@0 = check_context::migraphx::gpu::context -> float_type, {}, {}
@1 = hip::hip_allocate_memory[shape=int8_type, {752}, {1},id=main:scratch] -> int8_type, {752}, {1}
@2 = load[offset=0,end=384](@1) -> float_type, {2, 3, 16}, {48, 16, 1}
scales = @param:scales -> float_type, {6}, {1}
@4 = reshape_lazy[dims={2, 3, 1}](scales) -> float_type, {2, 3, 1}, {3, 1, 1}
@5 = multibroadcast[out_lens={2, 3, 16},out_dyn_dims={}](@4) -> float_type, {2, 3, 16}, {3, 1, 0}
@6 = gpu::code_object[code_object=5016,symbol_name=contiguous_kernel,global=48,local=1024,](@5,@2) -> float_type, {2, 3, 16}, {48, 16, 1}
@7 = load[offset=656,end=752](@1) -> uint8_type, {2, 48}, {48, 1}
b = @param:b -> uint8_type, {2, 3, 8}, {24, 8, 1}
@9 = reshape_lazy[dims={2, -1}](b) -> uint8_type, {2, 24}, {24, 1}
@10 = gpu::code_object[code_object=5016,symbol_name=unpack_int4_kernel,global=48,local=1024,](@9,@7) -> uint8_type, {2, 48}, {48, 1}
@11 = reshape_lazy[dims={2, 48}](@6) -> float_type, {2, 48}, {48, 1}
@12 = load[offset=384,end=648](@1) -> float_type, {2, 33}, {33, 1}
@13 = slice[axes={1},starts={0},ends={33}](@11) -> float_type, {2, 33}, {48, 1}
@14 = slice[axes={1},starts={0},ends={33}](@10) -> uint8_type, {2, 33}, {48, 1}
@15 = gpu::code_object[code_object=5088,symbol_name=dequantizelinear_kernel,global=66,local=1024,](@14,@13,@12) -> float_type, {2, 33}, {33, 1}
main:#output_0 = @param:main:#output_0 -> float_type, {2, 2}, {2, 1}
a = @param:a -> float_type, {2, 33}, {33, 1}
@18 = gpu::code_object[code_object=5216,symbol_name=mlir_transpose_dot,global=256,local=256,](a,@15,main:#output_0) -> float_type, {2, 2}, {2, 1}
@19 = @return(@18)
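The following small sketch reproduces the shapes in the dump above from MatMulNBits-style parameters, to make the reshaping and broadcasting easier to follow. N=2, K=33 and block_size=16 are inferred from the dump (b is {2, 3, 8} and the slices end at 33), not read from the test file. The layout of b as [N, blocks, bytes per block] is assumed, and both helpers are hypothetical.

```python
import math


def matmulnbits_dq_shapes(n, k, block_size, bits=4):
    """Hypothetical helper: derive the shapes seen in the IR dump."""
    n_blocks = math.ceil(k / block_size)   # blocks per row; K is padded up to a multiple of block_size
    blob_bytes = block_size * bits // 8    # bytes holding one block of packed 4-bit values
    packed_b = (n, n_blocks, blob_bytes)   # param b
    unpacked = (n, n_blocks * block_size)  # b reshaped to (n, -1), then each byte split in two
    scales = (n * n_blocks,)               # param scales: one scale per (row, block)
    broadcast = (n, n_blocks, block_size)  # multibroadcast of the reshaped scales
    sliced = (n, k)                        # slice with ends={k} drops the padding
    return packed_b, unpacked, scales, broadcast, sliced


def scale_index(row, col, n_blocks, block_size):
    """Index into the flat scales tensor used for element (row, col) after broadcasting."""
    return row * n_blocks + col // block_size


print(matmulnbits_dq_shapes(2, 33, 16))
# ((2, 3, 8), (2, 48), (6,), (2, 3, 16), (2, 33))
```

In the dump, these steps run as four separate kernels (contiguous_kernel, unpack_int4_kernel, dequantizelinear_kernel and mlir_transpose_dot). The broadcast scales are materialized by a contiguous kernel before dequantization, which is the unfused behaviour this issue is about.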

@lakhinderwalia
Contributor

  1. There is an issue with the type of zero_point: when it is uint8, there is no easy way to fuse it using the types MLIR currently supports.
  2. Additionally, a bug in MatMulNBits parsing is being fixed. That fix won't solve the problem above, however.
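The arithmetic behind an unsigned-versus-signed choice is worth spelling out. Dequantization depends only on the difference q - zero_point, so shifting both the 4-bit value and the zero-point by the same constant (here 8, which maps [0, 15] onto [-8, 7]) leaves the result unchanged. The sketch below checks that identity exhaustively for 4-bit values. It only illustrates why an unsigned tensor can, in principle, be re-expressed with signed types; it is not a description of what any of the linked PRs implements.

```python
def dequant(q, zp, scale):
    """Scalar dequantization: (q - zero_point) * scale."""
    return (q - zp) * scale


def check_shift_invariance(shift=8, scale=0.25):
    """Check (q - zp) * s == ((q - shift) - (zp - shift)) * s for all uint4 pairs."""
    for q in range(16):
        for zp in range(16):
            q_s, zp_s = q - shift, zp - shift  # shifted into the signed int4 range [-8, 7]
            assert -8 <= q_s <= 7 and -8 <= zp_s <= 7
            if dequant(q, zp, scale) != dequant(q_s, zp_s, scale):
                return False
    return True


print(check_shift_invariance())  # True
```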

@lakhinderwalia
Contributor

lakhinderwalia commented Oct 28, 2024

#3566

With this PR, the new test case fuses properly:
bin/driver compile ../test/onnx/matmulnbits_mm2_signed_test.onnx.

However, the uint8 case remains unfixed:
bin/driver compile ../test/onnx/matmulnbits_mm2_test.onnx

@lakhinderwalia
Contributor

PR #3609 addresses this issue.

@pfultz2
Collaborator

pfultz2 commented Nov 12, 2024

It does enable some fusions, but we are still not fusing all the DQs in the model. More investigation is needed.

@lakhinderwalia
Contributor

More relevant PRs: #3645, #3629, #3632, #3637, #3582
