-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Proton][Dialect] Add Proton Device Memory Buffer Init and Allocate Pass #5606
base: proton-dev
Are you sure you want to change the base?
Conversation
708cb2c
to
73d584d
Compare
@@ -256,6 +256,9 @@ def make_ttgir(mod, metadata, options): | |||
passes.common.add_canonicalizer(pm) | |||
passes.common.add_cse(pm) | |||
passes.common.add_symbol_dce(pm) | |||
|
|||
proton.passes.ttgpuir.add_allocate_device_buffer(pm) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need a separate pass for adding those ops.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the alternative? Just allocating the device buffer allocation in the same pass as the init op insertion? Or get rid of the TTGIR init op altogether and just do this entirely on the LLVM IR level?
I'm somewhat worried about hiding the details of how all this is done behind the scenes from developers. So my thinking was if we have an "init" op at the TTGIR level that maps to the device allocation buffer on the LLVM IR things would at least have a 1:1 mapping. But I think that needs to be done in two passes then - one to add the init op if a record op is found and then another to rewrite the init op into the LLVM level device address space buffer.
@@ -231,6 +232,9 @@ struct ConvertTritonAMDGPUToLLVM | |||
mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, | |||
targetInfo, commonBenefit); | |||
|
|||
mlir::triton::proton::populateInitDeviceBufferOpToLLVMPattern( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We shouldn't couple our proton LLVM lowering logic into other backends (nvidia and amd) llvm lowering rules. Instead, we are going to have a separate proton conversion pass which populates all kinds of proton op lowering patterns.
@@ -267,6 +267,7 @@ def make_ttgir(mod, metadata, opt, capability): | |||
nvidia.passes.ttnvgpuir.add_fence_insertion(pm) | |||
nvidia.passes.ttnvgpuir.add_tma_lowering(pm) | |||
passes.common.add_canonicalizer(pm) | |||
proton.passes.ttgpuir.add_allocate_device_buffer(pm) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here, we don't want this to be a separate pass
@@ -153,6 +153,9 @@ struct ConvertTritonGPUToLLVM | |||
targetInfo, benefit); | |||
mlir::triton::proton::populateRecordOpToLLVMPattern(typeConverter, patterns, | |||
targetInfo, benefit); | |||
mlir::triton::proton::populateInitDeviceBufferOpToLLVMPattern( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as amd's comment
@@ -62,4 +64,15 @@ def TT_RecordOp : TT_Proton_Op<"record", [DeclareOpInterfaceMethods<MemoryEffect | |||
let assemblyFormat = " `(` operands `)` attr-dict"; | |||
} | |||
|
|||
|
|||
def TT_InitDeviceBufferOp : TT_Proton_Op<"init_device_buffer", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you rename it to something like buffer_alloc
(consistent with local_alloc
)? It is better to have the return type of memory descriptor. Then we could use smem and device_buffer in different RecordOp strategies interchangeably.
@@ -62,4 +64,15 @@ def TT_RecordOp : TT_Proton_Op<"record", [DeclareOpInterfaceMethods<MemoryEffect | |||
let assemblyFormat = " `(` operands `)` attr-dict"; | |||
} | |||
|
|||
|
|||
def TT_InitDeviceBufferOp : TT_Proton_Op<"init_device_buffer", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is also a good idea to have its dual operation: buffer_dealloc. Not necessary to be in the same PR but good to think about it as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think that is actually possible in this case. We're allocating a compile time, fixed size, stack buffer in device memory. I don't know how we would dealloc that. Keep in mind this is different from the shared memory buffer which is dynamically allocated.
@@ -10,6 +10,11 @@ void populateRecordOpToLLVMPattern(LLVMTypeConverter &typeConverter, | |||
RewritePatternSet &patterns, | |||
const TargetInfoBase &targetInfo, | |||
PatternBenefit benefit); | |||
void populateInitDeviceBufferOpToLLVMPattern(LLVMTypeConverter &typeConverter, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After we have our own llvm lowering conversion pass, let's this move to proton/dialect/lib/...
third_party/proton/dialect/lib/TritonProtonToLLVM/InitDeviceBufferOpToLLVM.cpp
Show resolved
Hide resolved
|
||
|
||
@triton.jit | ||
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the end-to-end testing, you could manually construct a TTGIR with the buffer_alloc_op and read write to it and finally write it back to gmem to check its value in python.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, I think we'll want to go through in another PR and add all the end to end testing at once to make sure we have the code coverage we want.
Add the init and allocation of the Proton dialect device buffer that can be used in place of the shared memory buffer. The device buffer is just a module local, zero initialized, stack buffer in address space(1).