Skip to content

Commit

Permalink
[Codegen] Lower hal.interface.workgroup.size in GPU codegen (iree-o…
Browse files Browse the repository at this point in the history
…rg#18145)

Cleanup related to iree-org#16554

---------

Signed-off-by: nithinsubbiah <nithinsubbiah@gmail.com>
  • Loading branch information
nithinsubbiah authored Aug 8, 2024
1 parent a7e5788 commit 6f88125
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 11 deletions.
10 changes: 6 additions & 4 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,10 +533,12 @@ void populateConvertSharedMemoryAllocOps(RewritePatternSet &patterns) {
}

void populateLowerHALInterfaceOp(RewritePatternSet &patterns) {
patterns.insert<HALInterfaceWorkgroupOpsConverter<
IREE::HAL::InterfaceWorkgroupIDOp, gpu::BlockIdOp>,
HALInterfaceWorkgroupOpsConverter<
IREE::HAL::InterfaceWorkgroupCountOp, gpu::GridDimOp>>(
patterns.add<HALInterfaceWorkgroupOpsConverter<
IREE::HAL::InterfaceWorkgroupIDOp, gpu::BlockIdOp>,
HALInterfaceWorkgroupOpsConverter<
IREE::HAL::InterfaceWorkgroupSizeOp, gpu::BlockDimOp>,
HALInterfaceWorkgroupOpsConverter<
IREE::HAL::InterfaceWorkgroupCountOp, gpu::GridDimOp>>(
patterns.getContext());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,33 @@ hal.executable @masked_load_store {
// CHECK: %[[MASK_BIT:.+]] = llvm.icmp "sgt" {{.*}} : vector<1xi64>
// CHECK: llvm.intr.masked.load %{{.*}}, %[[MASK_BIT]]
// CHECK: llvm.intr.masked.store %{{.*}}, %[[MASK_BIT]]

// -----
// Test workgroup size lowering
// Pipeline layout used by this test case: no push constants and a single
// descriptor set (index 0) declaring three storage-buffer bindings (0, 1, 2).
// Only binding 0 is referenced by the function below.
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>
// Test input for the CUDA target: a single function that reads
// hal.interface.workgroup.size along dimensions 0 and 1. The CHECK lines that
// follow this executable expect one nvvm.read.ptx.sreg.ntid.{x,y} per query.
// NOTE(review): the variant symbol is @rocm although the target is
// <"cuda", "cuda-nvptx-fb">; presumably a leftover from the ROCm copy of this
// test — confirm and rename if so.
hal.executable private @interface_wg_size {
hal.executable.variant @rocm target(<"cuda", "cuda-nvptx-fb">) {
hal.executable.export @interface_wg_size layout(#pipeline_layout) attributes {
workgroup_size = [32: index, 1: index, 1: index]
}
builtin.module attributes {} {
func.func @interface_wg_size() {
%c0 = arith.constant 0.0 : f32
// Query the workgroup size along dimension 0 and dimension 1.
%workgroup_size_x = hal.interface.workgroup.size[0] : index
%workgroup_size_y = hal.interface.workgroup.size[1] : index
%subspan = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) : memref<64x64xf32>
// Both queried sizes are used as store indices; presumably this keeps the
// queries from being removed as dead code before lowering.
memref.store %c0, %subspan[%workgroup_size_x, %workgroup_size_y] : memref<64x64xf32>
return
}
}
}
}
// CHECK-LABEL: llvm.func @interface_wg_size
// CHECK: %[[WGDIMX:.+]] = nvvm.read.ptx.sreg.ntid.x
// CHECK: %[[WGDIMY:.+]] = nvvm.read.ptx.sreg.ntid.y
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,33 @@ hal.executable @masked_load_store {
// CHECK: %[[MASK_BIT:.+]] = llvm.icmp "sgt" {{.*}} : vector<1xi64>
// CHECK: llvm.intr.masked.load %{{.*}}, %[[MASK_BIT]]
// CHECK: llvm.intr.masked.store %{{.*}}, %[[MASK_BIT]]

// -----
// Test workgroup size lowering
// Pipeline layout used by this test case: no push constants and a single
// descriptor set (index 0) declaring three storage-buffer bindings (0, 1, 2).
// Only binding 0 is referenced by the function below.
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>
// Test input for the ROCm target: a single function that reads
// hal.interface.workgroup.size along dimensions 0 and 1. The CHECK lines that
// follow this executable expect one rocdl.workgroup.dim.{x,y} per query.
hal.executable private @interface_wg_size {
hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
hal.executable.export @interface_wg_size layout(#pipeline_layout) attributes {
workgroup_size = [32: index, 1: index, 1: index]
}
builtin.module attributes {} {
func.func @interface_wg_size() {
%c0 = arith.constant 0.0 : f32
// Query the workgroup size along dimension 0 and dimension 1.
%workgroup_size_x = hal.interface.workgroup.size[0] : index
%workgroup_size_y = hal.interface.workgroup.size[1] : index
%subspan = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) : memref<64x64xf32>
// Both queried sizes are used as store indices; presumably this keeps the
// queries from being removed as dead code before lowering.
memref.store %c0, %subspan[%workgroup_size_x, %workgroup_size_y] : memref<64x64xf32>
return
}
}
}
}
// CHECK-LABEL: llvm.func @interface_wg_size
// CHECK: %[[WGDIMX:.+]] = rocdl.workgroup.dim.x
// CHECK: %[[WGDIMY:.+]] = rocdl.workgroup.dim.y
16 changes: 9 additions & 7 deletions compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,10 +302,10 @@ struct HALInterfaceLoadConstantConverter final
}
};

/// A pattern to convert hal.interface.workgroup.id/count into corresponding
/// SPIR-V Builtin ops.
/// A pattern to convert hal.interface.workgroup.id/count/size into
/// corresponding SPIR-V Builtin ops.
template <typename InterfaceOpTy, spirv::BuiltIn builtin>
struct HALInterfaceWorkgroupIdAndCountConverter final
struct HALInterfaceWorkgroupOpsConverter final
: OpConversionPattern<InterfaceOpTy> {
using OpConversionPattern<InterfaceOpTy>::OpConversionPattern;

Expand Down Expand Up @@ -656,10 +656,12 @@ void ConvertToSPIRVPass::runOnOperation() {
// Add IREE HAL interface op conversions.
patterns.add<
HALInterfaceLoadConstantConverter,
HALInterfaceWorkgroupIdAndCountConverter<
IREE::HAL::InterfaceWorkgroupIDOp, spirv::BuiltIn::WorkgroupId>,
HALInterfaceWorkgroupIdAndCountConverter<
IREE::HAL::InterfaceWorkgroupCountOp, spirv::BuiltIn::NumWorkgroups>>(
HALInterfaceWorkgroupOpsConverter<IREE::HAL::InterfaceWorkgroupIDOp,
spirv::BuiltIn::WorkgroupId>,
HALInterfaceWorkgroupOpsConverter<IREE::HAL::InterfaceWorkgroupSizeOp,
spirv::BuiltIn::WorkgroupSize>,
HALInterfaceWorkgroupOpsConverter<IREE::HAL::InterfaceWorkgroupCountOp,
spirv::BuiltIn::NumWorkgroups>>(
typeConverter, context);

// Performs a preliminary step to analyze all hal.interface.binding.subspan ops
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,44 @@ hal.executable private @interface_wg_id {

// -----

// Pipeline layout used by this test case: no push constants and a single
// descriptor set (index 0) declaring three storage-buffer bindings (0, 1, 2).
// Only binding 0 is referenced by the function below.
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>
// Test input for the Vulkan SPIR-V target: a single function that reads
// hal.interface.workgroup.size along dimensions 0 and 1. The CHECK lines that
// follow this executable expect a "WorkgroupSize" built-in global variable
// that is loaded and indexed with spirv.CompositeExtract once per query.
hal.executable private @interface_wg_size {
hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb">) {
hal.executable.export @interface_wg_size layout(#pipeline_layout) attributes {
workgroup_size = [32: index, 1: index, 1: index]
}
builtin.module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Int64, Shader], []>, #spirv.resource_limits<>>} {
func.func @interface_wg_size() {
%c0 = arith.constant 0.0 : f32
// Query the workgroup size along dimension 0 and dimension 1.
%workgroup_size_x = hal.interface.workgroup.size[0] : index
%workgroup_size_y = hal.interface.workgroup.size[1] : index
%subspan = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) : memref<64x64xf32, #spirv.storage_class<StorageBuffer>>
// Both queried sizes are used as store indices; presumably this keeps the
// queries from being removed as dead code before lowering.
memref.store %c0, %subspan[%workgroup_size_x, %workgroup_size_y] : memref<64x64xf32, #spirv.storage_class<StorageBuffer>>
return
}
}
}
}

// CHECK-LABEL: spirv.module
// CHECK-DAG: spirv.GlobalVariable @[[WGSIZE:.+]] built_in("WorkgroupSize")
// CHECK-DAG: spirv.GlobalVariable @[[BIND:.+]] bind(0, 0)
// CHECK: %[[CST0:.+]] = spirv.Constant 0.000000e+00 : f32
// CHECK: %[[ADDR1:.+]] = spirv.mlir.addressof @[[WGSIZE]]
// CHECK: %[[VAL1:.+]] = spirv.Load "Input" %[[ADDR1]]
// CHECK: %[[WGSIZEX:.+]] = spirv.CompositeExtract %[[VAL1]][0 : i32]
// CHECK: %[[ADDR2:.+]] = spirv.mlir.addressof @[[WGSIZE]]
// CHECK: %[[VAL2:.+]] = spirv.Load "Input" %[[ADDR2]]
// CHECK: %[[WGSIZEY:.+]] = spirv.CompositeExtract %[[VAL2]][1 : i32]

// -----

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
Expand Down

0 comments on commit 6f88125

Please sign in to comment.