Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMD] Send target features to backend #5565

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/test/unit/tools/test_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def test_ttgir_to_ptx():
kernel_path = os.path.join(tmp_dir, "empty_kernel.ttgir")
with open(kernel_path, "w") as fp:
fp.write(src)
k = triton.compile(kernel_path, target=GPUTarget("cuda", 80, 32))
k = triton.compile(kernel_path, target=GPUTarget("cuda", 80, 32, ""))
ptx = k.asm["ptx"]
assert ".target sm_80" in ptx
assert ".address_size 64" in ptx
1 change: 1 addition & 0 deletions python/triton/backends/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class GPUTarget(object):
# Target architecture, e.g., 90 (for cuda compute capability), gfx940 (for hip)
arch: Union[int, str]
warp_size: int
features: str


class BaseBackend(metaclass=ABCMeta):
Expand Down
2 changes: 1 addition & 1 deletion python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def __init__(self, src, metadata_group, hash):
metadata['cluster_dims'] = tuple(metadata['cluster_dims'])
# JSON serialization dumps the target as a dict. Restore it to a GPUTarget.
target = metadata['target']
metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size'])
metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size'], target['features'])
KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys())))
self.metadata = KernelMetadata(**metadata)
backend = make_backend(self.metadata.target)
Expand Down
23 changes: 11 additions & 12 deletions third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class HIPOptions:
debug: bool = False
sanitize_overflow: bool = True
arch: str = None
features: str = None
supported_fp8_dtypes: Tuple[str] = ("fp8e5", )
deprecated_fp8_dtypes: Tuple[str] = ()
default_dot_input_precision: str = "ieee"
Expand Down Expand Up @@ -110,7 +111,10 @@ def __init__(self, target: GPUTarget) -> None:
self.binary_ext = "hsaco"

def parse_options(self, opts) -> Any:
args = {'arch': self.target.arch}
target_features = self.target.features.split(',')
if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
target_features.append('+xnack')
args = {'arch': self.target.arch, 'features': ','.join(target_features)}

if "supported_fp8_dtypes" not in opts:
supported_fp8_dtypes = set(HIPOptions.supported_fp8_dtypes)
Expand Down Expand Up @@ -301,10 +305,7 @@ def make_llir(src, metadata, options):
context = llvm.context()
llvm_mod = llvm.to_module(mod, context)
amd.attach_target_triple(llvm_mod)
target_features = ''
if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
target_features = '+xnack'
llvm.attach_datalayout(llvm_mod, amd.TARGET_TRIPLE, options.arch, target_features)
llvm.attach_datalayout(llvm_mod, amd.TARGET_TRIPLE, options.arch, options.features)

# Set various control constants on the LLVM module so that device
# libraries can resolve references to them.
Expand All @@ -330,8 +331,8 @@ def make_llir(src, metadata, options):
fns[0].add_fn_attr("amdgpu-waves-per-eu", f"{options.waves_per_eu}")
denormal_mode = "preserve-sign" if options.allow_flush_denorm else "ieee"
fns[0].add_fn_attr("denormal-fp-math-f32", denormal_mode)
fns[0].add_fn_target_feature(options.features)
if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
fns[0].add_fn_target_feature("+xnack")
fns[0].add_fn_asan_attr()

# Hint the compiler that we'd like the firmware to set the kernel arguments
Expand All @@ -351,7 +352,7 @@ def make_llir(src, metadata, options):
paths = [path for (name, path) in options.extern_libs if amd.need_extern_lib(llvm_mod, name)]
llvm.link_extern_libs(llvm_mod, paths)

llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion)
llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, options.features, [], options.enable_fp_fusion)

# Get some metadata
metadata["shared"] = src.get_int_attr("ttg.shared")
Expand All @@ -371,18 +372,16 @@ def make_amdgcn(src, metadata, options):
assert len(names) == 1
metadata["name"] = names[0]
# llvm -> hsaco
amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, '', [], options.enable_fp_fusion, False)
amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, options.features, [],
options.enable_fp_fusion, False)
if os.environ.get("AMDGCN_ENABLE_DUMP", "0") == "1":
print("// -----// AMDGCN Dump //----- //")
print(amdgcn)
return amdgcn

@staticmethod
def make_hsaco(src, metadata, options):
target_features = ''
if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
target_features = '+xnack'
hsaco = amd.assemble_amdgcn(src, options.arch, target_features)
hsaco = amd.assemble_amdgcn(src, options.arch, options.features)

rocm_path = HIPBackend.path_to_rocm_lld()
with tempfile.NamedTemporaryFile() as tmp_out:
Expand Down
10 changes: 8 additions & 2 deletions third_party/amd/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,9 +509,15 @@ def is_active():
def get_current_target(self):
device = self.get_current_device()
device_properties = self.utils.get_device_properties(device)
arch = device_properties['arch']
warp_size = device_properties['warpSize']
return GPUTarget("hip", arch.split(':')[0], warp_size)
arch = device_properties['arch'].split(':', 1)
arch_name = arch[0]
# reformat features
arch_features = []
for feat in arch[1].split(':'):
modifier = feat[-1:]
arch_features.append(modifier + feat[:-1])
return GPUTarget("hip", arch_name, warp_size, ','.join(arch_features))

def get_active_torch_device(self):
import torch
Expand Down
2 changes: 1 addition & 1 deletion third_party/nvidia/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def get_current_target(self):
capability = self.get_device_capability(device)
capability = capability[0] * 10 + capability[1]
warp_size = 32
return GPUTarget("cuda", capability, warp_size)
return GPUTarget("cuda", capability, warp_size, "")

def get_active_torch_device(self):
import torch
Expand Down
Loading