Remove permutation of KV cache #398
base: main
This PR introduces a memory layout attribute for the MemoryType that allows the user to specify a physical shape that differs from the logical shape. This is useful in scenarios like kv-caches where certain dimensions physically are quite large but map to fixed logical dimensions. Signed-off-by: Harsh Menon <harsh@nod-labs.com>
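As a rough illustration of the idea (not the API this PR adds), the sketch below models a memory type whose logical shape is what the kernel indexes against and whose physical shape is what determines strides. The names `MemoryLayout`, `Memory`, `strides`, and `linear_offset` are hypothetical and chosen only for this example, and row-major storage is assumed.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class MemoryLayout:
    """Hypothetical layout attribute: the shape actually allocated in memory."""
    shape: Tuple[int, ...]


@dataclass(frozen=True)
class Memory:
    """Hypothetical memory type with a logical shape and an optional layout."""
    logical_shape: Tuple[int, ...]
    layout: Optional[MemoryLayout] = None

    @property
    def physical_shape(self) -> Tuple[int, ...]:
        # Fall back to the logical shape when no layout is given.
        return self.layout.shape if self.layout else self.logical_shape

    def strides(self) -> Tuple[int, ...]:
        # Row-major strides computed from the physical shape (an assumption).
        strides = []
        running = 1
        for extent in reversed(self.physical_shape):
            strides.append(running)
            running *= extent
        return tuple(reversed(strides))

    def linear_offset(self, index: Tuple[int, ...]) -> int:
        # Bounds are checked against the logical shape; offsets use physical strides.
        assert len(index) == len(self.logical_shape)
        for i, extent in zip(index, self.logical_shape):
            assert 0 <= i < extent, "index outside logical shape"
        return sum(i * s for i, s in zip(index, self.strides()))


# A kv-cache-like buffer: the middle dimension is logically 4 but physically 16.
kv = Memory(logical_shape=(2, 4, 8), layout=MemoryLayout(shape=(2, 16, 8)))
plain = Memory(logical_shape=(2, 4, 8))

print(kv.strides(), plain.strides())
print(kv.linear_offset((1, 3, 7)), plain.linear_offset((1, 3, 7)))
```

The same logical index lands at a different linear offset once a layout is attached, which is the behavior a separate physical shape is meant to capture.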
d84210c to f516332
Maybe add a test that demonstrates usage of this? It looks like it's still None in all the lit tests.
lit_tests/kernel/wave/expansion.py
Outdated
@@ -309,21 +309,21 @@ def test_gemm():
 # CHECK-SAME: (%b, 4, None, (), None)

 # CHECK-NEXT: %mma_M:0_N:0_K:0
-# CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:0_N:0_K:0, None)
+# CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:0_N:0_K:0, None, None)
What if set_node_indices (or an earlier pass) actually wrote the MMA dimensional mapping to the node? Disadvantage: lit tests would be more verbose. Advantage: intermediate trace printing would make it much clearer what is going on. I was actually thinking that having all node attributes (including vector shapes and indices) attached to the nodes would be helpful.
Doesn't the custom printing already print the indices? (I don't think it prints the vector shapes, but we can add that as well.)
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
Turns out I didn't have to introduce a new dimensional mapping construct; this is already handled by mapping.
f516332 to d2fa3b3
This PR removes the permutation of the key and value cache prior to kernel invocation. Signed-off-by: Harsh Menon <harsh@nod-labs.com>
d2fa3b3 to f5d1c07
This PR removes the permutation of the key and value cache prior to kernel invocation. In order to accomplish this, we add an additional attribute to the MMA operator that specifies which dimensions map to the M, N, K dimensions of the matrix multiplication.
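To make the idea concrete, here is a minimal sketch in plain Python (not the kernel language) of a matrix multiply that is told which axis of each operand plays the role of M, N, or K. The function `mma_with_mapping`, its helpers, and the tuple-of-role-names format are assumptions made for illustration; they are not the attribute this PR adds. The point is only that an operand stored as [K, N] can be consumed directly, without first permuting it to [N, K].

```python
from typing import Dict, List, Sequence

Matrix = List[List[float]]


def _extents(mat: Matrix, axes: Sequence[str]) -> Dict[str, int]:
    # Map each role name ("M", "N", or "K") to the extent of the axis it labels.
    return {axes[0]: len(mat), axes[1]: len(mat[0])}


def _get(mat: Matrix, axes: Sequence[str], coords: Dict[str, int]) -> float:
    # Index a 2D operand through the role names assigned to its axes.
    return mat[coords[axes[0]]][coords[axes[1]]]


def mma_with_mapping(
    lhs: Matrix,
    lhs_axes: Sequence[str],
    rhs: Matrix,
    rhs_axes: Sequence[str],
    acc: Matrix,
) -> Matrix:
    """Hypothetical MMA: out[m][n] = acc[m][n] + sum over k of lhs(m, k) * rhs(n, k)."""
    lhs_sizes = _extents(lhs, lhs_axes)
    rhs_sizes = _extents(rhs, rhs_axes)
    assert lhs_sizes["K"] == rhs_sizes["K"], "reduction extents must match"
    m, n, k = lhs_sizes["M"], rhs_sizes["N"], lhs_sizes["K"]

    out = [row[:] for row in acc]  # copy so the accumulator is not mutated
    for i in range(m):
        for j in range(n):
            for r in range(k):
                a = _get(lhs, lhs_axes, {"M": i, "K": r})
                b = _get(rhs, rhs_axes, {"N": j, "K": r})
                out[i][j] += a * b
    return out


a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]    # stored as [M, K]
b_nk = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]  # stored as [N, K]
b_kn = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # same data stored as [K, N]
zeros = [[0.0, 0.0], [0.0, 0.0]]

# Same result from either storage order; only the axis mapping changes.
out_nk = mma_with_mapping(a, ("M", "K"), b_nk, ("N", "K"), zeros)
out_kn = mma_with_mapping(a, ("M", "K"), b_kn, ("K", "N"), zeros)
print(out_nk == out_kn)
```

In this sketch the caller changes only the axis labels, not the data, which mirrors the motivation for dropping the key and value cache permutation before the kernel call.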