
[Proton][Dialect] Middle-end support of the Proton Dialect and the frontend Python package #5677

Open
wants to merge 11 commits into base: proton-dev
Conversation

fywkevin
Contributor

This PR introduces the Proton dialect's middle-end compiler support (operators and a lowering pass) and a frontend Python package for end users. Namely:

  • the middle-end Proton operators that work with the triton_gpu dialect abstraction.
  • a lowering pass that leverages the existing global_scratch_memory infrastructure to allocate GPU memory; the profiler's buffer size is deduced from the freely available SMEM, which we take over as the backend's temporary storage.
  • a Python package that lets users interact with the intra-kernel profiler and handles memory allocation. It lives in a separate namespace (triton.intraprof) to avoid a circular import (triton.profiler depends on triton.compiler); a rough illustration follows this list.
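For illustration, here is a small sketch of the import cycle that the separate namespace avoids; only the three module names stated above are assumed, and the comments describe the dependency direction rather than the actual file layout.

# triton.profiler (proton's host-side package) already depends on the compiler:
import triton.compiler       # triton.profiler -> triton.compiler

# The runtime/compiler side needs the new frontend to set up profiler memory.
# If that frontend also lived under triton.profiler, triton.compiler would import
# triton.profiler, which imports triton.compiler again: a circular import.
# A separate namespace keeps the dependency one-way:
import triton.intraprof      # triton.compiler -> triton.intraprof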

@fywkevin requested review from Jokeren and CRobeck on January 23, 2025 03:54
@fywkevin requested a review from ptillet as a code owner on January 23, 2025 03:54
#include "Transforms/Passes.h.inc"

namespace {
int maxPowerof2(unsigned int n) {
Contributor

Are you trying to implement llvm::NextPowerOf2?

@@ -518,6 +519,7 @@ def __call__(self, gridX, gridY, gridZ, stream, function, *args):
grid_size = gridX * gridY * gridZ
alloc_size = grid_size * self.global_scratch_size
global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream)
set_alloc_state(global_scratch, grid_size, self.global_scratch_size, self.global_scratch_align, stream)
Contributor

Yeah, it might be better to just add a new profile_buffer argument to the launch function

Contributor

This way global_scratch can be used as it was

Contributor

And we can profile kernels with both global_scratch and profile_buffer.

Contributor

My initial thought was actually something like

def alloc_fn(size, alignment, stream):
    # assuming the allocator hook receives (size, alignment, stream),
    # matching how _allocation._allocator is called in the diff above
    total_size = size + profile_buffer_size
    ret = torch.empty(total_size, device="cuda", dtype=torch.int8)
    # hand the profiler the tail of the allocation, past the original `size` bytes
    libproton.set_profile_buffer(ret.data_ptr() + size)
    return ret

But still, it's not as clean as using an independent buffer.
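A minimal sketch of the independent-buffer variant discussed above, extrapolated from the launcher diff; the profile_buffer_size and profile_buffer_align fields and the way the extra pointer reaches the kernel are assumptions for illustration, not code from this PR.

def __call__(self, gridX, gridY, gridZ, stream, function, *args):
    grid_size = gridX * gridY * gridZ
    # the existing global scratch allocation stays exactly as it was
    alloc_size = grid_size * self.global_scratch_size
    global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream)
    # hypothetical second allocation dedicated to the profiler
    profile_size = grid_size * self.profile_buffer_size
    profile_buffer = _allocation._allocator(profile_size, self.profile_buffer_align, stream)
    # ... global_scratch and profile_buffer are then forwarded as separate kernel arguments

With two separate allocations, global_scratch keeps its original meaning and a kernel can use both buffers at the same time, which is the point raised above.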

let assemblyFormat = " `(` operands `)` attr-dict";
}

def PT_CircularRecordOp : PT_Op<"circular_record", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
Contributor

NIT: it doesn't have to be an interface; a trait like MemoryEffects<[MemRead<GlobalMemory>]> is fine.

Contributor

I'm a bit confused about the name here: why do we need an additional op? Can't we just use RecordOp, since it also has a strategy field?

Contributor Author
@fywkevin Jan 23, 2025

My thought was that RecordOp gets lowered to a strategy-specific RecordOp like CircularRecordOp. In the backend, we would have different strategies (e.g., a circular buffer, periodic flush, sampling, etc.). These ops would take different runtime arguments and have separate LLVM lowerings.

Contributor

Are there any fields in CircularRecordOp that differ from RecordOp? If not, let's keep it simple and use RecordOp.

Contributor Author

CircularRecordOp has inputs like the memdesc and the index pointer. It is different from RecordOp, which isn't supposed to be aware of these. There might be other "backend" RecordOps (for example PeriodicFlushRecordOp or SamplingRecordOp). If we have a strategy that periodically flushes the SMEM buffer, we need other arguments like the proton global buffer's size, index, and alignment. If it is a sampling strategy, the op would have arguments about sampling state and attributes.

let assemblyFormat = " `(` operands `)` attr-dict";
}

def PT_CircularRecordOp : PT_Op<"circular_record", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Record a GPU hardware event into a circular buffer";
Contributor

NIT: it's not always a "hardware" event; we can record any event into the circular buffer, including ones derived purely from static analysis.

let summary = "Record a GPU hardware event into a circular buffer";

let description = [{
The intra kernel profiler records an event into a circular buffer backed by the shared memory `data`.
Contributor

NIT: Maybe unify the use of event and metric

let assemblyFormat = [{$data `,` $indexPtr attr-dict `:` qualified(type($data)) `,` type($indexPtr)}];
}

def PT_FinalizeOp : PT_Op<"finalize", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
Contributor

ditto for MemoryEffectsOpInterface

let assemblyFormat = [{$data `,` $indexPtr `,` $ptr attr-dict `:` qualified(type($data)) `,` type($indexPtr) `,` type($ptr)}];
}

def PT_InitOp : PT_Op<"init", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
Contributor

ditto


include "mlir/Pass/PassBase.td"

def ProtonLowering: Pass<"proton-lowering", "mlir::ModuleOp"> {
Contributor

Lowering passes should be moved to the Conversion folder following triton's convention.

Contributor Author

In our case, we need to globally add init and finalize operators, allocate the memory resources, and bind them to the lowered record operators. Conversion works per operator locally and can't go beyond the target operator. A similar case in triton is TritonNvidiaGPUTMALoweringPass (nvidia.passes.ttnvgpuir.add_tma_lowering).

Contributor

I think these conversions have to be decoupled.

We could have one transform pass that instruments Init and Finalize ops.
One conversion pass that updates the global and shared memory usages.
And another conversion pass that converts all ops to LLVM.

Contributor Author

The attribute lowering is fine. What about CircularRecordOp, which takes the results of init and local_alloc? We need to rewrite RecordOp -> CircularRecordOp (that's the case I'm referring to). I don't think this can be done in the conversion abstraction.

Contributor

How about the following?

  • Remove granularity, strategy, and metric from PT_RecordOp. Instead, we let the concrete instrumentation pass determine what and how to measure.
  • Set up an instrumentation infrastructure; for now, you just have a single pass that measures latency.
  • Then we have the following passes:
  1. A transform pass that inserts Init and Finalize ops, which is like preparing common stuff for instrumentation.
  2. Individual instrumentation passes that convert PT_RecordOp. For our current measurement, we need a pass that converts PT_RecordOps to PT_CircularRecordOps. Steps 1 and 2 can be combined if we inherit from a template class that includes the insertion of Init and Finalize ops; it's just a design issue.
  3. One conversion pass that updates the global and shared memory usages.
  4. Another conversion pass that converts all ops to LLVM (a rough ordering sketch follows this list).
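To make this ordering concrete, here is a rough Python sketch of how the staged pipeline could be assembled on the backend side; every proton_passes.* name below is a hypothetical placeholder for illustration, not an API introduced by this PR.

def add_proton_instrumentation(pm):
    # 1. transform pass: insert proton.init / proton.finalize (common preparation)
    proton_passes.add_instrumentation_prep(pm)
    # 2. instrumentation pass: rewrite proton.record into proton.circular_record
    proton_passes.add_circular_record(pm)
    # 3. conversion pass: update ttg.shared / ttg.global_scratch_memory_size for the profiler
    proton_passes.add_memory_allocation(pm)
    # 4. the remaining proton ops are lowered together with the TritonGPU -> LLVM conversion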

Contributor Author

For the shared memory and global memory attribute updates, aren't they done in a transform pass? ttg.shared gets set in the AllocateSharedMemory pass, and ttg.global_scratch_memory_size gets set in TritonGPUGlobalScratchAllocationPass. They are not conversion passes?

Contributor

Wait, I probably missed the question.

They are not conversion passes?

AllocateSharedMemory is indeed a conversion pass, isn't it? I'm suggesting that our step 3 can be done either before or after it.

Contributor

I meant that AllocateSharedMemory is placed under the conversion directory.

Contributor Author

Remove granularity, strategy, and metric from PT_RecordOp. Instead, we let the concrete instrumentation pass determine what and how to measure.

OK, I understand that it would have the benefit of aligning with the existing .enter_scope and .exit_scope APIs in proton. In general, that makes sense to me. I finally figured out that by "conversion pass" you mean the folder location, not the ConversionRewriter or OpRewriter kind of thing :-)

Contributor

Yes!

m->setAttr("ttg.shared",
mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32),
maxSharedMem));
m->setAttr("ttg.global_scratch_memory_size",
Contributor

Once we use a separate buffer, I think these fields do not have to be updated.


