
[Proton][Dialect] Add Proton Device Memory Buffer Init and Allocate Pass #5606

Open · wants to merge 32 commits into base: proton-dev

Conversation

@CRobeck (Contributor) commented Jan 14, 2025:
Add the initialization and allocation of the Proton dialect device buffer, which can be used in place of the shared memory buffer. The device buffer is just a module-local, zero-initialized, stack buffer in address space (1).
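As a rough illustration of what such a buffer could lower to (the symbol name and size here are hypothetical, not taken from this PR), a module-local, zero-initialized global in address space 1 looks like this in LLVM IR:

```llvm
; Hypothetical lowering sketch: a module-local, zero-initialized
; buffer placed in address space 1 (device/global memory).
@__proton_device_buffer = internal addrspace(1) global [1024 x i32] zeroinitializer, align 16
```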

@CRobeck changed the base branch from main to proton-dev Jan 23, 2025
@CRobeck changed the title [WIP][Proton][Dialect] Add Initial Infrastructure For Proton Shared Memory Buffer → [Proton][Dialect] Add Infrastructure For Proton Device Memory Buffer Jan 23, 2025
@CRobeck marked this pull request as ready for review Jan 23, 2025
@CRobeck requested a review from ptillet as a code owner Jan 23, 2025
@CRobeck changed the title [Proton][Dialect] Add Infrastructure For Proton Device Memory Buffer → [Proton][Dialect] Add Infrastructure For Proton Device Memory Buffer Pass Jan 24, 2025
@CRobeck changed the title [Proton][Dialect] Add Infrastructure For Proton Device Memory Buffer Pass → [Proton][Dialect] Add Proton Device Memory Buffer Pass Jan 24, 2025
@CRobeck changed the title [Proton][Dialect] Add Proton Device Memory Buffer Pass → [Proton][Dialect] Add Proton Device Memory Buffer Init and Allocate Pass Jan 25, 2025
@fywkevin self-assigned this Jan 25, 2025
@@ -256,6 +256,9 @@ def make_ttgir(mod, metadata, options):
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)

proton.passes.ttgpuir.add_allocate_device_buffer(pm)
Contributor:
I don't think we need a separate pass for adding those ops.

@CRobeck (Contributor Author), Jan 25, 2025:

What's the alternative? Just performing the device buffer allocation in the same pass as the init op insertion? Or getting rid of the TTGIR init op altogether and doing this entirely at the LLVM IR level?

I'm somewhat worried about hiding the details of how all this is done behind the scenes from developers. My thinking was that if we have an "init" op at the TTGIR level that maps to the device buffer allocation at the LLVM IR level, things would at least have a 1:1 mapping. But I think that needs to be done in two passes: one to add the init op if a record op is found, and another to rewrite the init op into the LLVM-level device address space buffer.
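A rough sketch of the two-pass flow being described (op and pass names are illustrative, based on the ops in this PR):

```mlir
// Pass 1 (TTGIR level): if a proton.record op is present, insert an
// init op for the device buffer at the top of the kernel.
tt.func @kernel() {
  proton.init_device_buffer ()   // inserted by pass 1
  proton.record ()               // pre-existing record op
  tt.return
}

// Pass 2 (LLVM lowering): rewrite proton.init_device_buffer into a
// module-local, zero-initialized global in device address space 1.
```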

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp Outdated Show resolved Hide resolved
@@ -231,6 +232,9 @@ struct ConvertTritonAMDGPUToLLVM
mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns,
targetInfo, commonBenefit);

mlir::triton::proton::populateInitDeviceBufferOpToLLVMPattern(
Contributor:
We shouldn't couple our Proton LLVM lowering logic into the other backends' (NVIDIA and AMD) LLVM lowering rules. Instead, we are going to have a separate Proton conversion pass which populates all of the Proton op lowering patterns.

@@ -267,6 +267,7 @@ def make_ttgir(mod, metadata, opt, capability):
nvidia.passes.ttnvgpuir.add_fence_insertion(pm)
nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
passes.common.add_canonicalizer(pm)
proton.passes.ttgpuir.add_allocate_device_buffer(pm)
Contributor:
Same here; we don't want this to be a separate pass.

@@ -153,6 +153,9 @@ struct ConvertTritonGPUToLLVM
targetInfo, benefit);
mlir::triton::proton::populateRecordOpToLLVMPattern(typeConverter, patterns,
targetInfo, benefit);
mlir::triton::proton::populateInitDeviceBufferOpToLLVMPattern(
Contributor:
Same as the AMD comment.

@@ -62,4 +64,15 @@ def TT_RecordOp : TT_Proton_Op<"record", [DeclareOpInterfaceMethods<MemoryEffect
let assemblyFormat = " `(` operands `)` attr-dict";
}


def TT_InitDeviceBufferOp : TT_Proton_Op<"init_device_buffer", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
Contributor:
Could you rename it to something like buffer_alloc (consistent with local_alloc)? It would also be better for it to have a memory-descriptor return type; then we could use smem and device_buffer interchangeably in different RecordOp strategies.
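A hedged sketch of what the suggested rename might look like in TableGen (the result-type constraint name and assembly format are assumptions modeled on the existing TT_InitDeviceBufferOp definition, not part of this PR):

```tablegen
// Hypothetical: buffer_alloc returning a memory descriptor, so smem
// and device_buffer can be used interchangeably by RecordOp strategies.
def TT_BufferAllocOp : TT_Proton_Op<"buffer_alloc",
    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
  let results = (outs TTG_MemDescType:$result);
  let assemblyFormat = "attr-dict `:` type($result)";
}
```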

Contributor:
It is also a good idea to have its dual operation: buffer_dealloc. It doesn't need to be in the same PR, but it's good to think about as well.

@CRobeck (Contributor Author), Jan 25, 2025:
I don't think that is actually possible in this case. We're allocating a compile-time, fixed-size, stack buffer in device memory; I don't know how we would dealloc that. Keep in mind this is different from the shared memory buffer, which is dynamically allocated.

@@ -10,6 +10,11 @@ void populateRecordOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const TargetInfoBase &targetInfo,
PatternBenefit benefit);
void populateInitDeviceBufferOpToLLVMPattern(LLVMTypeConverter &typeConverter,
Contributor:
After we have our own LLVM lowering conversion pass, let's move this to proton/dialect/lib/...



@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
Contributor:
For the end-to-end testing, you could manually construct a TTGIR module with the buffer_alloc op, read from and write to the buffer, and finally write it back to gmem to check its value in Python.
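For example (hypothetical op names and types, following the suggestion above), a hand-written TTGIR test kernel might allocate the buffer, store a known value, and copy it back to global memory for a Python-side assertion:

```mlir
// Hypothetical test kernel: write a sentinel into the proton buffer,
// read it back, and store it to a gmem pointer the test can inspect.
tt.func @buffer_roundtrip(%out: !tt.ptr<i32>) {
  %buf = proton.buffer_alloc : !ttg.memdesc<1024xi32>
  // ... write a known value into %buf, then load it back ...
  // ... tt.store the loaded value to %out ...
  tt.return
}
```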

@CRobeck (Contributor Author):

Right. I think we'll want to go through in another PR and add all the end-to-end testing at once, to make sure we have the code coverage we want.
