From 0d2c78073ed2f5092e90e9fe39ad0dbaa3e109c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrea=20=F0=9F=A6=88?= Date: Tue, 25 Jun 2024 18:44:09 +0200 Subject: [PATCH] Ensure IREE GPU dialect is registered for all GPU targets (fixes #17736) (#17737) Before this commit, it was only registered for ROCm, so the compiler would sometimes crash during compilation (including for some things compiled as part of the IREE build process) when trying to use IREE GPU dialect attributes if the ROCm target wasn't enabled. Signed-off-by: Andrea Faulds --- compiler/plugins/target/CUDA/CUDATarget.cpp | 2 +- compiler/plugins/target/MetalSPIRV/MetalSPIRVTarget.cpp | 3 ++- compiler/plugins/target/VulkanSPIRV/VulkanSPIRVTarget.cpp | 2 +- compiler/plugins/target/WebGPUSPIRV/WebGPUSPIRVTarget.cpp | 3 ++- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/compiler/plugins/target/CUDA/CUDATarget.cpp b/compiler/plugins/target/CUDA/CUDATarget.cpp index 4d61464bdf18..7e78ae3eeb2c 100644 --- a/compiler/plugins/target/CUDA/CUDATarget.cpp +++ b/compiler/plugins/target/CUDA/CUDATarget.cpp @@ -450,7 +450,7 @@ class CUDATargetBackend final : public TargetBackend { // `LLVMGPULowerExecutableTargetPass`. registry.insert(); + transform::TransformDialect, IREE::GPU::IREEGPUDialect>(); mlir::registerBuiltinDialectTranslation(registry); mlir::registerLLVMDialectTranslation(registry); mlir::registerNVVMDialectTranslation(registry); diff --git a/compiler/plugins/target/MetalSPIRV/MetalSPIRVTarget.cpp b/compiler/plugins/target/MetalSPIRV/MetalSPIRVTarget.cpp index 25e8e51843c9..4fa2b03c1094 100644 --- a/compiler/plugins/target/MetalSPIRV/MetalSPIRVTarget.cpp +++ b/compiler/plugins/target/MetalSPIRV/MetalSPIRVTarget.cpp @@ -112,7 +112,8 @@ class MetalSPIRVTargetBackend : public TargetBackend { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); + IREE::Flow::FlowDialect, spirv::SPIRVDialect, + IREE::GPU::IREEGPUDialect>(); } void diff --git a/compiler/plugins/target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/compiler/plugins/target/VulkanSPIRV/VulkanSPIRVTarget.cpp index 5fdeb54ff693..33b0ac1e5969 100644 --- a/compiler/plugins/target/VulkanSPIRV/VulkanSPIRVTarget.cpp +++ b/compiler/plugins/target/VulkanSPIRV/VulkanSPIRVTarget.cpp @@ -126,7 +126,7 @@ class VulkanSPIRVTargetBackend : public TargetBackend { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); + gpu::GPUDialect, IREE::GPU::IREEGPUDialect>(); } void diff --git a/compiler/plugins/target/WebGPUSPIRV/WebGPUSPIRVTarget.cpp b/compiler/plugins/target/WebGPUSPIRV/WebGPUSPIRVTarget.cpp index 0a376912652a..8fd7c531316d 100644 --- a/compiler/plugins/target/WebGPUSPIRV/WebGPUSPIRVTarget.cpp +++ b/compiler/plugins/target/WebGPUSPIRV/WebGPUSPIRVTarget.cpp @@ -106,7 +106,8 @@ class WebGPUSPIRVTargetBackend : public TargetBackend { // pipeline created by buildTranslationPassPipeline) void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); + spirv::SPIRVDialect, gpu::GPUDialect, + IREE::GPU::IREEGPUDialect>(); } void