This repository has been archived by the owner on Jun 27, 2024. It is now read-only.
[Triton] Block dimension should be inferred from the Triton function #125
Labels: enhancement (New feature or request)
Today the block dimension is hardcoded to 64 here: https://github.com/openxla/openxla-nvgpu/blob/628fc75bcfe00c0ab2551824914a896668f21909/compiler/src/openxla/compiler/nvgpu/Dialect/TritonFlow/Conversion/ConvertTritonToFlowDispatch.cpp#L264-L266
The block dimension should instead be computed from the num-warps attribute of the Triton function; see the example in jax: https://github.com/google/jax/blob/e785574b2b7b13f063a5399b13c41c422c54ea14/jaxlib/gpu/triton.cc#LL68C54-L68C54
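A minimal sketch of the computation, assuming a warp size of 32 threads (as on current NVIDIA GPUs) and that the block's x dimension is simply the warp count times the warp size. The names `kNumThreadsPerWarp` and `BlockDimXFromNumWarps` are hypothetical and used only for illustration; they are not taken from the openxla-nvgpu code.

```cpp
#include <cassert>
#include <cstdint>

// Assumption: 32 threads per warp, which holds for current NVIDIA GPUs.
constexpr uint32_t kNumThreadsPerWarp = 32;

// Hypothetical helper: derives the x block dimension (threads per block)
// from the num-warps value attached to a Triton function.
inline uint32_t BlockDimXFromNumWarps(uint32_t num_warps) {
  assert(num_warps > 0 && "a kernel needs at least one warp");
  return num_warps * kNumThreadsPerWarp;
}
```

Under these assumptions, the hardcoded value of 64 is only correct for functions compiled with 2 warps; a function compiled with 4 warps would need a block dimension of 128.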