From 6f88125a5cb38bf3122544062bfc31308ca00863 Mon Sep 17 00:00:00 2001 From: Nithin Meganathan <18070964+nithinsubbiah@users.noreply.github.com> Date: Thu, 8 Aug 2024 16:57:36 -0700 Subject: [PATCH] [Codegen] Lower `hal.interface.workgroup.size` in GPU codegen (#18145) Cleanup related to https://github.com/iree-org/iree/issues/16554 --------- Signed-off-by: nithinsubbiah --- .../Codegen/LLVMGPU/ConvertToLLVM.cpp | 10 +++-- .../Codegen/LLVMGPU/test/convert_to_nvvm.mlir | 30 +++++++++++++++ .../LLVMGPU/test/convert_to_rocdl.mlir | 30 +++++++++++++++ .../Codegen/SPIRV/ConvertToSPIRVPass.cpp | 16 ++++---- .../Codegen/SPIRV/test/convert_to_spirv.mlir | 38 +++++++++++++++++++ 5 files changed, 113 insertions(+), 11 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp index 8319790f3182..821623ddf609 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp @@ -533,10 +533,12 @@ void populateConvertSharedMemoryAllocOps(RewritePatternSet &patterns) { } void populateLowerHALInterfaceOp(RewritePatternSet &patterns) { - patterns.insert, - HALInterfaceWorkgroupOpsConverter< - IREE::HAL::InterfaceWorkgroupCountOp, gpu::GridDimOp>>( + patterns.add, + HALInterfaceWorkgroupOpsConverter< + IREE::HAL::InterfaceWorkgroupSizeOp, gpu::BlockDimOp>, + HALInterfaceWorkgroupOpsConverter< + IREE::HAL::InterfaceWorkgroupCountOp, gpu::GridDimOp>>( patterns.getContext()); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir index 0bd40d0d034e..ea5f33f32520 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir @@ -426,3 +426,33 @@ hal.executable @masked_load_store { // CHECK: %[[MASK_BIT:.+]] = llvm.icmp "sgt" {{.*}} : vector<1xi64> // CHECK: llvm.intr.masked.load %{{.*}}, %[[MASK_BIT]] // CHECK: llvm.intr.masked.store %{{.*}}, %[[MASK_BIT]] + +// ----- +// Test workgroup size lowering +#pipeline_layout = #hal.pipeline.layout, + #hal.descriptor_set.binding<1, storage_buffer>, + #hal.descriptor_set.binding<2, storage_buffer> + ]> +]> +hal.executable private @interface_wg_size { + hal.executable.variant @rocm target(<"cuda", "cuda-nvptx-fb">) { + hal.executable.export @interface_wg_size layout(#pipeline_layout) attributes { + workgroup_size = [32: index, 1: index, 1: index] + } + builtin.module attributes {} { + func.func @interface_wg_size() { + %c0 = arith.constant 0.0 : f32 + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %subspan = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) : memref<64x64xf32> + memref.store %c0, %subspan[%workgroup_size_x, %workgroup_size_y] : memref<64x64xf32> + return + } + } + } +} +// CHECK-LABEL: llvm.func @interface_wg_size +// CHECK: %[[WGDIMX:.+]] = nvvm.read.ptx.sreg.ntid.x +// CHECK: %[[WGDIMY:.+]] = nvvm.read.ptx.sreg.ntid.y diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir index b4dd4ca0a8ca..d5c993174407 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir @@ -129,3 +129,33 @@ hal.executable @masked_load_store { // CHECK: %[[MASK_BIT:.+]] = llvm.icmp "sgt" {{.*}} : vector<1xi64> // CHECK: llvm.intr.masked.load %{{.*}}, %[[MASK_BIT]] // CHECK: llvm.intr.masked.store %{{.*}}, %[[MASK_BIT]] + +// ----- +// Test workgroup size lowering +#pipeline_layout = #hal.pipeline.layout, + #hal.descriptor_set.binding<1, storage_buffer>, + #hal.descriptor_set.binding<2, storage_buffer> + ]> +]> +hal.executable private @interface_wg_size { + hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { + hal.executable.export @interface_wg_size layout(#pipeline_layout) attributes { + workgroup_size = [32: index, 1: index, 1: index] + } + builtin.module attributes {} { + func.func @interface_wg_size() { + %c0 = arith.constant 0.0 : f32 + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %subspan = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) : memref<64x64xf32> + memref.store %c0, %subspan[%workgroup_size_x, %workgroup_size_y] : memref<64x64xf32> + return + } + } + } +} +// CHECK-LABEL: llvm.func @interface_wg_size +// CHECK: %[[WGDIMX:.+]] = rocdl.workgroup.dim.x +// CHECK: %[[WGDIMY:.+]] = rocdl.workgroup.dim.y diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp index 442d530b0509..0aad84dcff4f 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp @@ -302,10 +302,10 @@ struct HALInterfaceLoadConstantConverter final } }; -/// A pattern to convert hal.interface.workgroup.id/count into corresponding -/// SPIR-V Builtin ops. +/// A pattern to convert hal.interface.workgroup.id/count/size into +/// corresponding SPIR-V Builtin ops. template -struct HALInterfaceWorkgroupIdAndCountConverter final +struct HALInterfaceWorkgroupOpsConverter final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -656,10 +656,12 @@ void ConvertToSPIRVPass::runOnOperation() { // Add IREE HAL interface op conversions. patterns.add< HALInterfaceLoadConstantConverter, - HALInterfaceWorkgroupIdAndCountConverter< - IREE::HAL::InterfaceWorkgroupIDOp, spirv::BuiltIn::WorkgroupId>, - HALInterfaceWorkgroupIdAndCountConverter< - IREE::HAL::InterfaceWorkgroupCountOp, spirv::BuiltIn::NumWorkgroups>>( + HALInterfaceWorkgroupOpsConverter, + HALInterfaceWorkgroupOpsConverter, + HALInterfaceWorkgroupOpsConverter>( typeConverter, context); // Performs a prelimiary step to analyze all hal.interface.binding.subspan ops diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir index b51836711be0..9fcdab10a188 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir @@ -240,6 +240,44 @@ hal.executable private @interface_wg_id { // ----- +#pipeline_layout = #hal.pipeline.layout, + #hal.descriptor_set.binding<1, storage_buffer>, + #hal.descriptor_set.binding<2, storage_buffer> + ]> +]> +hal.executable private @interface_wg_size { + hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb">) { + hal.executable.export @interface_wg_size layout(#pipeline_layout) attributes { + workgroup_size = [32: index, 1: index, 1: index] + } + builtin.module attributes {spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { + func.func @interface_wg_size() { + %c0 = arith.constant 0.0 : f32 + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %subspan = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) : memref<64x64xf32, #spirv.storage_class> + memref.store %c0, %subspan[%workgroup_size_x, %workgroup_size_y] : memref<64x64xf32, #spirv.storage_class> + return + } + } + } +} + +// CHECK-LABEL: spirv.module +// CHECK-DAG: spirv.GlobalVariable @[[WGSIZE:.+]] built_in("WorkgroupSize") +// CHECK-DAG: spirv.GlobalVariable @[[BIND:.+]] bind(0, 0) +// CHECK: %[[CST0:.+]] = spirv.Constant 0.000000e+00 : f32 +// CHECK: %[[ADDR1:.+]] = spirv.mlir.addressof @[[WGSIZE]] +// CHECK: %[[VAL1:.+]] = spirv.Load "Input" %[[ADDR1:.+]] +// CHECK: %[[WGSIZEX:.+]] = spirv.CompositeExtract %[[VAL1]][0 : i32] +// CHECK: %[[ADDR2:.+]] = spirv.mlir.addressof @[[WGSIZE]] +// CHECK: %[[VAL2:.+]] = spirv.Load "Input" %[[ADDR2:.+]] +// CHECK: %[[WGSIZEY:.+]] = spirv.CompositeExtract %[[VAL2]][1 : i32] + +// ----- + #pipeline_layout = #hal.pipeline.layout,