[Proton][Dialect] Middle-end support of the Proton Dialect and the frontend Python package #5677

Draft: wants to merge 11 commits into proton-dev
4 changes: 4 additions & 0 deletions bin/RegisterTritonDialects.h
@@ -3,6 +3,7 @@
#include "amd/include/TritonAMDGPUTransforms/Passes.h"
#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h"
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
#include "third_party/proton/dialect/include/Transforms/Passes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
@@ -68,6 +69,9 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();

// Proton passes
mlir::triton::proton::registerProtonLowering();

// TODO: register Triton & TritonGPU passes
registry
.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
1 change: 1 addition & 0 deletions python/setup.py
@@ -689,6 +689,7 @@ def get_packages():
"triton/backends",
"triton/tools",
"triton/tools/extra",
"triton/intraprof",
]
packages += [f'triton/backends/{backend.name}' for backend in backends]
packages += get_extra_packages("language")
2 changes: 2 additions & 0 deletions python/triton/__init__.py
@@ -25,6 +25,7 @@
from . import language
from . import testing
from . import tools
from . import intraprof

__all__ = [
"autotune",
@@ -48,6 +49,7 @@
"TritonError",
"testing",
"tools",
"intraprof",
]

# -------------------------------------
9 changes: 9 additions & 0 deletions python/triton/intraprof/__init__.py
@@ -0,0 +1,9 @@
from .intra import (
init,
finalize,
activate,
deactivate,
set_alloc_state,
)

__all__ = ["init", "finalize", "activate", "deactivate", "set_alloc_state"]
92 changes: 92 additions & 0 deletions python/triton/intraprof/intra.py
@@ -0,0 +1,92 @@
from typing import Optional
from dataclasses import dataclass
import torch
import triton


@dataclass(frozen=True)
class Config:
max_shared: int = 0
alloc_scratch: int = 262144
alignment: int = 128


@dataclass
class State:
grid_size: int = 1
alignment: int = 128
scratch_size: int = 0
config: Config = Config()
stream: Optional[int] = None
profile_mem_cpu: Optional[torch.Tensor] = None


state: Optional[State] = None
global_scratch_mem: Optional[torch.Tensor] = None
activated: bool = False


def set_alloc_state(global_scratch: torch.Tensor, grid_size: int, scratch_size: int, alignment: int,
stream: Optional[int]):
global state
global global_scratch_mem
global activated

if not activated:
return

assert state, "profiler must be initialized"
state.grid_size = grid_size
state.scratch_size = scratch_size
state.alignment = alignment
state.stream = stream
global_scratch_mem = global_scratch


def init(config=dict()):
global state
global activated

if not activated:
return

state = State()
device = triton.runtime.driver.active.get_current_device()
shared_mem = triton.runtime.driver.active.utils.get_device_properties(device)["max_shared_mem"]
args = {'max_shared': shared_mem}
args.update({k: config[k] for k in Config.__dataclass_fields__.keys() if k in config})
state.config = Config(**args)


def finalize() -> Optional[State]:
global state
global global_scratch_mem
global activated

if not activated:
return None

assert state, "profiler must be initialized"
curr_state = state
size = curr_state.grid_size * curr_state.config.alloc_scratch
# TODO(fywkevin): copy profiling data to profile_mem_cpu, the offset depends on the alignment
curr_state.profile_mem_cpu = torch.empty(size, device="cpu", dtype=torch.int8)

state = None
global_scratch_mem = None
return curr_state


def _alloc_fn(size: int, alignment: int, stream: Optional[int]):
return torch.empty(size, device="cuda", dtype=torch.int8)


def activate():
global activated
activated = True
triton.set_allocator(_alloc_fn)


def deactivate():
global activated
activated = False
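
A minimal host-side usage sketch of the frontend API above. This is assumption-heavy: the copy kernel and its launch parameters are illustrative only, whether any profile data is actually recorded depends on proton.record ops being present in the compiled kernel (not shown here), and activate() is called before init() because init() returns early when the profiler is not activated.

import torch
import triton
import triton.language as tl
import triton.intraprof as profiler


@triton.jit
def _copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    # Illustrative kernel, not part of this PR.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)


src = torch.arange(1024, device="cuda", dtype=torch.float32)
dst = torch.empty_like(src)

profiler.activate()                                  # installs the allocator hook via triton.set_allocator
profiler.init({"alloc_scratch": 262144, "alignment": 128})
_copy_kernel[(4,)](src, dst, 1024, BLOCK=256)        # driver records grid/scratch info via set_alloc_state
state = profiler.finalize()                          # returns the State; profile_mem_cpu is allocated here
profiler.deactivate()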
9 changes: 9 additions & 0 deletions test/Proton/error.mlir
@@ -0,0 +1,9 @@
// RUN: triton-opt --split-input-file %s -proton-lowering="max-shared-mem=1024 scratch-mem=512 alignment=128" -verify-diagnostics

module attributes {"ttg.num-warps" = 8 : i32, ttg.shared = 128 : i32} {
// expected-error @+1 {{Global scratch memory for proton is not large enough}}
tt.func @insufficient_global_scratch() {
proton.record() {isStart = true, regionId = 1 : i32}
tt.return
}
} // end module
20 changes: 20 additions & 0 deletions test/Proton/ops.mlir
@@ -13,3 +13,23 @@ module {
} // end module

// -----

#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}>
#smem = #ttg.shared_memory
module {
// CHECK-LABEL: circular_record
tt.func @circular_record() {
// CHECK: proton.init
// CHECK-NEXT: ttg.local_alloc
// CHECK-NEXT: ttg.global_scratch_alloc{{.*}}nbytes = 231487
// CHECK-NEXT: proton.circular_record
// CHECK-NEXT: proton.finalize{{.*}}{size = 231487 : i32}
// CHECK-NEXT: tt.return
%0 = proton.init : !tt.ptr<i32>
%1 = ttg.local_alloc {allocation.offset = 213016 : i32} : () -> !ttg.memdesc<4096xi32, #shared, #smem, mutable>
%2 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 231487 : i32, ttg.global_scratch_memory_offset = 384 : i32} : !tt.ptr<i32>
proton.circular_record %1, %0 {isStart = false, regionId = 1 : i32} : !ttg.memdesc<4096xi32, #shared, #smem, mutable>, !tt.ptr<i32>
proton.finalize %1, %0, %2 {size = 231487 : i32} : !ttg.memdesc<4096xi32, #shared, #smem, mutable>, !tt.ptr<i32>, !tt.ptr<i32>
tt.return
}
} // end module
43 changes: 43 additions & 0 deletions test/Proton/proton_lowering.mlir
@@ -0,0 +1,43 @@
// RUN: triton-opt --split-input-file %s -proton-lowering="max-shared-mem=1024 scratch-mem=1024 alignment=32" | FileCheck %s

// CHECK: module attributes{{.*}}ttg.global_scratch_memory_alignment = 32{{.*}}ttg.global_scratch_memory_size = 1024
module attributes {"ttg.num-warps" = 8 : i32, ttg.shared = 512 : i32} {
// CHECK-LABEL: sufficient_global_scratch_size
tt.func @sufficient_global_scratch_size() {
proton.record() {isStart = true, regionId = 1 : i32}
// CHECK: proton.finalize{{.*}}{size = 1024 : i32}
// CHECK-NEXT: tt.return
tt.return
}
} // end module

// -----

// CHECK: module attributes{{.*}}ttg.global_scratch_memory_alignment = 128{{.*}}ttg.global_scratch_memory_size = 1280
module attributes {ttg.global_scratch_memory_alignment = 128, ttg.global_scratch_memory_size = 150, "ttg.num-warps" = 8 : i32, ttg.shared = 512 : i32} {
// CHECK-LABEL: unalign_global_scratch_alloc
tt.func @unalign_global_scratch_alloc() {
proton.record() {isStart = true, regionId = 1 : i32}
// CHECK: ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 1024 : i32, ttg.global_scratch_memory_offset = 256 : i32}
// CHECK: proton.finalize{{.*}}{size = 1024 : i32}
// CHECK-NEXT: tt.return
tt.return
}
} // end module

// -----

// CHECK: module attributes{{.*}}ttg.global_scratch_memory_alignment = 64{{.*}}ttg.global_scratch_memory_size = 1152
module attributes {ttg.global_scratch_memory_alignment = 64, ttg.global_scratch_memory_size = 128, "ttg.num-warps" = 8 : i32, ttg.shared = 512 : i32} {
// CHECK-LABEL: align_global_scratch_alloc
tt.func @align_global_scratch_alloc() {
proton.record() {isStart = true, regionId = 1 : i32}
// CHECK: %[[ARG0:.*]] = proton.init
// CHECK-NEXT: %[[ARG1:.*]] = ttg.local_alloc
// CHECK-NEXT: %[[ARG2:.*]] = ttg.global_scratch_alloc {alignment = 64 : i32, nbytes = 1024 : i32, ttg.global_scratch_memory_offset = 128 : i32} : !tt.ptr<i32>
// CHECK-NEXT: proton.circular_record %[[ARG1]], %[[ARG0]] {isStart = true, regionId = 1 : i32} : !ttg.memdesc<128xi32, #shared, #smem, mutable>, !tt.ptr<i32>
// CHECK-NEXT: proton.finalize %[[ARG1]], %[[ARG0]], %[[ARG2]] {size = 1024 : i32} : !ttg.memdesc<128xi32, #shared, #smem, mutable>, !tt.ptr<i32>, !tt.ptr<i32>
// CHECK-NEXT: tt.return
tt.return
}
} // end module
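
As a cross-check of the expected attributes in the cases above, here is a small sketch of the offset/size arithmetic they appear to exercise; the rounding rule is inferred from the FileCheck expectations, not taken from the pass implementation.

def aligned_offset(existing_size: int, alignment: int) -> int:
    # Round the preexisting global scratch size up to the requested alignment;
    # the proton allocation then starts at this offset.
    return (existing_size + alignment - 1) // alignment * alignment

# Second case: existing size 150, alignment 128 -> offset 256, total 256 + 1024 = 1280.
assert aligned_offset(150, 128) == 256
# Third case: existing size 128, alignment 64 -> offset 128, total 128 + 1024 = 1152.
assert aligned_offset(128, 64) == 128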
6 changes: 5 additions & 1 deletion third_party/nvidia/backend/compiler.py
@@ -1,6 +1,7 @@
from triton.backends.compiler import BaseBackend, GPUTarget
from triton._C.libtriton import ir, passes, llvm, nvidia
from triton._C.libtriton import ir, passes, llvm, nvidia, proton
from triton.runtime.errors import PTXASError
import triton.intraprof as profiler

from dataclasses import dataclass
import functools
@@ -293,6 +294,9 @@ def make_llir(src, metadata, options, capability):
passes.convert.add_index_to_llvmir(pm)
passes.ttgpuir.add_allocate_shared_memory(pm)
passes.ttgpuir.add_allocate_global_scratch_memory(pm)
if profiler.intra.activated:
config = profiler.intra.state.config
proton.passes.add_proton_lowering(pm, config.max_shared, config.alloc_scratch, config.alignment)
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
passes.convert.add_arith_to_llvmir(pm)
2 changes: 2 additions & 0 deletions third_party/nvidia/backend/driver.py
@@ -10,6 +10,7 @@
from triton.runtime import _allocation
from triton.backends.compiler import GPUTarget
from triton.backends.driver import GPUDriver
from triton.intraprof import set_alloc_state

dirname = os.path.dirname(os.path.realpath(__file__))
include_dir = [os.path.join(dirname, "include")]
@@ -518,6 +519,7 @@ def __call__(self, gridX, gridY, gridZ, stream, function, *args):
grid_size = gridX * gridY * gridZ
alloc_size = grid_size * self.global_scratch_size
global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream)
set_alloc_state(global_scratch, grid_size, self.global_scratch_size, self.global_scratch_align, stream)
Contributor: Yeah, it might be better to just add a new profile_buffer argument to the launch function

Contributor: This way global_scratch can be used as it was

Contributor: And we can profile kernels with both global_scratch and profile_buffer.

Contributor: My initial thought was actually something like

def alloc_fn():
    size = size + profile_buffer_size
    ret = torch.empty(size, device="cuda", dtype=torch.int8)
    libproton.set_profile_buffer(original pointer of ret + size)

But still it's not as clean as using an independent buffer

else:
global_scratch = None
self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args)
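
A hedged sketch of the alternative raised in the review thread above: keep global_scratch untouched and hand the launcher an independently allocated profile buffer. The extra launch argument, the helper name, and the per-block size are illustrative, not part of this PR's API.

import torch

def launch_with_profile_buffer(launcher, gridX, gridY, gridZ, stream, function,
                               cooperative_grid, global_scratch, args,
                               grid_size, profile_bytes_per_block):
    # Allocate the profiling buffer separately from the kernel's global scratch.
    profile_buffer = torch.empty(grid_size * profile_bytes_per_block,
                                 device="cuda", dtype=torch.int8)
    # The launcher would have to accept profile_buffer alongside global_scratch.
    launcher(gridX, gridY, gridZ, stream, function, cooperative_grid,
             global_scratch, profile_buffer, *args)
    return profile_buffer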
4 changes: 2 additions & 2 deletions third_party/proton/dialect/CMakeLists.txt
@@ -3,6 +3,6 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
add_subdirectory(include)
add_subdirectory(lib)
if(TRITON_BUILD_PYTHON_MODULE)
add_triton_plugin(TritonProton ${CMAKE_CURRENT_SOURCE_DIR}/triton_proton.cc)
target_link_libraries(TritonProton PRIVATE ProtonIR Python3::Module pybind11::headers)
add_triton_plugin(TritonProton ${CMAKE_CURRENT_SOURCE_DIR}/triton_proton.cc LINK_LIBS ProtonIR ProtonTransforms)
target_link_libraries(TritonProton PRIVATE Python3::Module pybind11::headers)
endif()
1 change: 1 addition & 0 deletions third_party/proton/dialect/include/CMakeLists.txt
@@ -1 +1,2 @@
add_subdirectory(Dialect)
add_subdirectory(Transforms)
18 changes: 14 additions & 4 deletions third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h
@@ -1,12 +1,14 @@
#ifndef TRITON_DIALECT_PROTON_IR_DIALECT_H_
#define TRITON_DIALECT_PROTON_IR_DIALECT_H_
#ifndef DIALECT_PROTON_IR_DIALECT_H_
#define DIALECT_PROTON_IR_DIALECT_H_

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/PatternMatch.h"
#include "proton/dialect/include/Dialect/Proton/IR/Dialect.h.inc"
#include "proton/dialect/include/Dialect/Proton/IR/OpsEnums.h.inc"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

#define GET_ATTRDEF_CLASSES
#include "proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.h.inc"
@@ -16,8 +18,16 @@

namespace mlir {
namespace triton {
namespace proton {} // namespace proton
namespace proton {

const int getGroupSize();

const int getBytesPerEntry();

const int getHeaderSize();

} // namespace proton
} // namespace triton
} // namespace mlir

#endif // TRITON_DIALECT_PROTON_IR_DIALECT_H_
#endif // DIALECT_PROTON_IR_DIALECT_H_