[Proton][Dialect] Middle-end support of the Proton Dialect and the frontend Python package #5677
base: proton-dev
Conversation
…ors and a lowering pass). 2. frontend Python package for end-users
#include "Transforms/Passes.h.inc"

namespace {
int maxPowerof2(unsigned int n) {
Are you trying to implement `llvm::NextPowerOf2`?
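For reference, `llvm::NextPowerOf2` returns the smallest power of two strictly greater than its argument, via the classic bit-smearing trick. A Python sketch of the same computation (an illustration of the semantics being referenced, not code from this PR):

```python
def next_power_of_2(a: int) -> int:
    """Smallest power of two strictly greater than a, for 64-bit a
    (mirrors the bit-smearing in llvm::NextPowerOf2)."""
    for shift in (1, 2, 4, 8, 16, 32):
        a |= a >> shift  # smear the highest set bit into all lower bits
    return a + 1
```

Note that under this definition `next_power_of_2(4)` is 8, not 4, which is worth checking against what `maxPowerof2` is intended to compute.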
@@ -518,6 +519,7 @@ def __call__(self, gridX, gridY, gridZ, stream, function, *args):
        grid_size = gridX * gridY * gridZ
        alloc_size = grid_size * self.global_scratch_size
        global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream)
        set_alloc_state(global_scratch, grid_size, self.global_scratch_size, self.global_scratch_align, stream)
Yeah, it might be better to just add a new `profile_buffer` argument to the launch function.
This way `global_scratch` can be used as it was.
And we can profile kernels with both `global_scratch` and `profile_buffer`.
My initial thought was actually something like:

```python
def alloc_fn(size):
    total = size + profile_buffer_size
    ret = torch.empty(total, device="cuda", dtype=torch.int8)
    # place the profile buffer right after the original region
    libproton.set_profile_buffer(ret.data_ptr() + size)
    return ret
```

But still, it's not as clean as using an independent buffer.
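To make the trade-off concrete, here is a minimal host-side sketch contrasting the two options under discussion. `allocate` and the returned dict layout are hypothetical stand-ins, not Triton/Proton APIs:

```python
# Option A: one allocation, profile buffer carved out of the tail.
# The kernel's base pointer is unchanged; the profile region starts
# right after the kernel's scratch region.
def alloc_with_tail_profile(scratch_size, profile_size, allocate):
    buf = allocate(scratch_size + profile_size)
    profile_ptr = buf["ptr"] + scratch_size
    return buf, profile_ptr

# Option B: two independent allocations, no pointer arithmetic
# coupling the scratch buffer to the profile buffer.
def alloc_independent(scratch_size, profile_size, allocate):
    return allocate(scratch_size), allocate(profile_size)
```

Option B is what the thread converges on: it leaves `global_scratch` untouched and avoids updating its size/alignment bookkeeping for profiling.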
  let assemblyFormat = " `(` operands `)` attr-dict";
}

def PT_CircularRecordOp : PT_Op<"circular_record", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
NIT: it doesn't have to be an interface; just a trait like `MemoryEffects<[MemRead<GlobalMemory>]>` is fine.
I'm a bit confused by the name here; why do we need an additional op? Can't we just use `RecordOp`, since it also has a strategy field?
My thought is that `RecordOp` gets lowered to a strategy-specific record op like `CircularRecordOp`. In the backend, we would have different strategies (e.g., a circular buffer, periodic flush, sampling, etc.). These ops would have different runtime arguments and a separate LLVM lowering.
Are there any fields in `CircularRecordOp` different from `RecordOp`? If not, let's keep it simple and use `RecordOp`.
`CircularRecordOp` has inputs like `memdesc` and index `ptr`. It is different from `RecordOp`, which isn't supposed to be aware of these. There might be other "backend" record ops (for example `PeriodicFlushRecordOp` or `SamplingRecordOp`). If we have a strategy that periodically flushes the SMEM buffer, we need other arguments such as the Proton global buffer's size, index, and alignment. If it is a sampling strategy, the op would have some arguments about sampling state and attributes.
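For intuition on what the circular strategy lowers to, here is a host-side sketch of circular-record semantics. The `buffer` and `index` values stand in for the shared-memory `memdesc` and the index pointer; this is an illustration of the scheme, not the actual generated code:

```python
def circular_record(buffer, index, entry):
    """Write entry at the current slot, wrapping when the buffer is full."""
    buffer[index % len(buffer)] = entry  # overwrite the oldest entry on wrap
    # The raw index keeps growing monotonically; readers recover the
    # logical position by taking it modulo the buffer length.
    return index + 1
```

The other strategies differ precisely in this inner step: a periodic-flush op would drain the buffer to global memory instead of wrapping, and a sampling op would conditionally skip the write.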
  let assemblyFormat = " `(` operands `)` attr-dict";
}

def PT_CircularRecordOp : PT_Op<"circular_record", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
  let summary = "Record a GPU hardware event into a circular buffer";
NIT: it's not always a "hardware" event; we can record any event into the circular buffer, including those derived purely from static analysis.
  let summary = "Record a GPU hardware event into a circular buffer";

  let description = [{
    The intra kernel profiler records an event into a circular buffer backed by the shared memory `data`.
NIT: Maybe unify the use of `event` and `metric`.
  let assemblyFormat = [{$data `,` $indexPtr attr-dict `:` qualified(type($data)) `,` type($indexPtr)}];
}

def PT_FinalizeOp : PT_Op<"finalize", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
Ditto for `MemoryEffectsOpInterface`.
  let assemblyFormat = [{$data `,` $indexPtr `,` $ptr attr-dict `:` qualified(type($data)) `,` type($indexPtr) `,` type($ptr)}];
}

def PT_InitOp : PT_Op<"init", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
ditto
include "mlir/Pass/PassBase.td"

def ProtonLowering : Pass<"proton-lowering", "mlir::ModuleOp"> {
Lowering passes should be moved to the `Conversion` folder, following Triton's convention.
In our case, we need to globally add `init` and `finalize` operators, allocate the memory resources, and bind them to the lowered record operators. Conversion works per operator locally and can't go beyond the target operator. A similar case in Triton is `TritonNvidiaGPUTMALoweringPass` (`nvidia.passes.ttnvgpuir.add_tma_lowering`).
I think these conversions have to be decoupled:
- One transform pass that instruments `Init` and `Finalize` ops.
- One conversion pass that updates the global and shared memory usages.
- Another conversion pass that converts all ops to LLVM.
The attributes lowering is fine. What about `CircularRecordOp`, which takes the results of `init` and `local_alloc`? We need to rewrite `RecordOp` -> `CircularRecordOp` (that's the case I'm referring to). I don't think this can be done in the conversion abstraction.
How about the following?

- Remove `granularity`, `strategy`, and `metric` from `PT_RecordOp`. Instead, we let the concrete instrumentation pass determine what to measure and how to measure it.
- Introduce an instrumentation infrastructure; for now, you just have a single pass that measures latency.
- Then we have the following passes:
  1. A transform pass that inserts Init and Finalize ops, which is like preparing common stuff for instrumentation.
  2. Individual instrumentation passes that convert `PT_RecordOp`. For our current measurement, we need a pass that converts `PT_RecordOp`s to `PT_CircularRecordOp`s. (1 and 2 can be combined if we inherit from a template class that includes insertion of `Init` and `Finalize` ops; it's just a design issue.)
  3. One conversion pass that updates the global and shared memory usages.
  4. Another conversion pass that converts all ops to LLVM.
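Orderings like this are easiest to reason about as a linear pipeline of module-to-module transforms. A toy sketch of the proposed sequence (the pass names mirror the discussion above and are illustrative, not real pass entry points):

```python
def run_pipeline(module, passes):
    """Apply each pass in order; every pass maps a module to a module."""
    for p in passes:
        module = p(module)
    return module

# Stub passes: each just records that it ran, to show the ordering.
pipeline = [
    lambda m: m + ["insert-init-finalize"],   # transform: instrument Init/Finalize
    lambda m: m + ["record-to-circular"],     # instrumentation: RecordOp -> CircularRecordOp
    lambda m: m + ["update-memory-attrs"],    # conversion: global/shared memory usages
    lambda m: m + ["convert-to-llvm"],        # conversion: all ops -> LLVM
]
```

The key property the decoupling buys is that the instrumentation step in the middle can be swapped (circular, periodic flush, sampling) without touching the passes before or after it.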
For the shared memory and global memory attribute updates, aren't they done in a transform pass? `ttg.shared` gets set in the `AllocateSharedMemory` pass and `ttg.global_scratch_memory_size` gets set in `TritonGPUGlobalScratchAllocationPass`. They are not conversion passes?
Wait, I probably missed the question.

> They are not conversion passes?

`AllocateSharedMemory` is indeed a conversion pass, isn't it? I'm suggesting that our step 3 can be done either before or after it.
I meant that `AllocateSharedMemory` is placed under the `conversion` directory.
> Remove granularity, strategy, and metric from PT_RecordOp. Instead, we let the concrete instrumentation pass determines what to measure and how to measure.

OK, I understand that it would have the benefit of aligning with the existing `.enter_scope` and `.exit_scope` APIs in Proton. In general, that makes sense to me. I finally figured out that when you say "conversion pass" you mean the folder location, not the `ConversionRewriter` or `OpRewriter` kind of thing :-)
Yes!
m->setAttr("ttg.shared",
           mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32),
                                  maxSharedMem));
m->setAttr("ttg.global_scratch_memory_size",
Once we use a separate buffer, I think these fields do not have to be updated.
This PR introduces the Proton dialect's mid-end compiler support (operators and a lowering pass), and a frontend Python package for end-users. Namely: