From bcc4ad5632b051d5800db5a831465779a4f17151 Mon Sep 17 00:00:00 2001
From: Bangtian Liu
Date: Mon, 9 Dec 2024 15:31:20 -0500
Subject: [PATCH] [tuner]: use property functions from the IREE lowering
 config Python bindings (#662)

Now that https://github.com/iree-org/iree/pull/19376 has landed, all of
the tuner's helper functions for reading lowering configurations can be
removed. Instead, we can use the property functions exposed by the
LoweringConfig Python bindings directly.

This PR is part of the task tracked in
https://github.com/nod-ai/shark-ai/issues/453: use the IREE bindings for
compilation info (including lowering_config and translation_info).
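As a minimal sketch of the new access pattern (the property and helper
names below are taken from this patch; the import paths and the tuner_ctx
fixture are assumptions mirroring the test setup, not a definitive API):

    # Assumed imports: the tuner sources use the IREE Python bindings'
    # iree_gpu dialect module; the tests import common relatively.
    from iree.compiler.dialects import iree_gpu
    from tuner.tuner import common  # assumed package layout

    mma_attr = iree_gpu.MMAAttr.get(iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16)
    lowering_config = common.get_lowering_config(
        tuner_ctx=tuner_ctx,  # assumed TunerContext fixture, as in the tests
        mma_kind=mma_attr,
        workgroup=[128, 320, 0],
        reduction=[0, 0, 32],
        subgroup_m_count=1,
        subgroup_n_count=4,
    )

    # Property access replaces the removed common.py helpers:
    intrinsic = lowering_config.mma_kind                    # was get_intrinsic(config)
    m_count, n_count = lowering_config.subgroup_count_mn    # was get_subgroup_{m,n}_count
    workgroup_sizes = lowering_config.workgroup_tile_sizes  # -> [128, 320, 0]
    reduction_sizes = lowering_config.reduction_tile_sizes  # -> [0, 0, 32]

Since subgroup_count_mn returns an (m, n) tuple, pairs of helper calls
collapse into a single destructuring assignment throughout the diff below.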
---------

Signed-off-by: Bangtian Liu
---
 tuner/tuner/candidate_gen.py        | 84 ++++++++++++++---------------
 tuner/tuner/candidate_gen_test.py   | 36 ++++++-------
 tuner/tuner/common.py               | 34 ------------
 tuner/tuner/common_test.py          |  5 +-
 tuner/tuner/dispatch_parser.py      | 36 +------------
 tuner/tuner/dispatch_parser_test.py | 29 +++-------
 6 files changed, 68 insertions(+), 156 deletions(-)

diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py
index c903ec85f..bc01bb709 100644
--- a/tuner/tuner/candidate_gen.py
+++ b/tuner/tuner/candidate_gen.py
@@ -38,16 +38,19 @@ tune_logger = logging.getLogger("tune")
 
 
-# TODO: remove the argument 'workgroup_sizes' and 'reduction_sizes'.
+
 def apply_configuration(
     template: list[str],
     configuration: Configuration,
-    workgroup_sizes: list[int],
-    reduction_sizes: list[int],
 ) -> str:
-    intrinsic = get_intrinsic(configuration)
-    subgroup_m_count = get_subgroup_m_count(configuration)
-    subgroup_n_count = get_subgroup_n_count(configuration)
+    lowering_config = configuration.lowering_config
+    intrinsic = lowering_config.mma_kind
+    (
+        subgroup_m_count,
+        subgroup_n_count,
+    ) = lowering_config.subgroup_count_mn
+    workgroup_sizes = lowering_config.workgroup_tile_sizes
+    reduction_sizes = lowering_config.reduction_tile_sizes
     tune_logger.info(f"Applying: {configuration}")
     expr0 = re.compile(
         r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
     )
@@ -125,9 +128,12 @@ class MmtTuner(DispatchTuner, MmtParser):
     def get_transform_function_mmt(
         self, problem_size: ProblemSize, functionName: str, configuration: Configuration
     ) -> str:
-        intrinsic = get_intrinsic(configuration)
-        subgroup_m_count = get_subgroup_m_count(configuration)
-        subgroup_n_count = get_subgroup_n_count(configuration)
+        lowering_config = configuration.lowering_config
+        intrinsic = lowering_config.mma_kind
+        (
+            subgroup_m_count,
+            subgroup_n_count,
+        ) = lowering_config.subgroup_count_mn
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -167,8 +173,6 @@ def apply_params(
         modified += apply_configuration(
             template,
             configuration,
-            get_mmt_workgroup_sizes(configuration),
-            get_mmt_reduction_sizes(configuration),
         )
         embeddable = indent(
             self.get_transform_function_mmt(problem_size, f"match_op", configuration),
@@ -193,15 +197,12 @@ def get_transform_function_conv(
         filter = f"tensor<{problem_size.rhs_type}>"
         output = f"tensor<{dynamic_batch_output_ty}>"
 
-        workgroup_sizes = ", ".join(
-            map(str, self.get_conv_workgroup_sizes(configuration))
-        )
-        reduction_sizes = ", ".join(
-            map(str, self.get_conv_reduction_sizes(configuration))
-        )
-        intrinsic = get_intrinsic(configuration)
-        subgroup_m_count = get_subgroup_m_count(configuration)
-        subgroup_n_count = get_subgroup_n_count(configuration)
+        lowering_config = configuration.lowering_config
+        intrinsic = lowering_config.mma_kind
+        (
+            subgroup_m_count,
+            subgroup_n_count,
+        ) = lowering_config.subgroup_count_mn
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -246,8 +247,6 @@ def apply_params(
         modified += apply_configuration(
             template,
             configuration,
-            self.get_conv_workgroup_sizes(configuration),
-            self.get_conv_reduction_sizes(configuration),
         )
         embeddable = indent(
             self.get_transform_function_conv(problem_size, f"match_op", configuration),
@@ -263,15 +262,12 @@ def get_transform_function_broadcast_rhs_mmt(
         functionName: str,
         configuration: Configuration,
     ) -> str:
-        workgroup_sizes = ", ".join(
-            map(str, get_batch_mmt_workgroup_sizes(configuration))
-        )
-        reduction_sizes = ", ".join(
-            map(str, get_batch_mmt_reduction_sizes(configuration))
-        )
-        intrinsic = get_intrinsic(configuration)
-        subgroup_m_count = get_subgroup_m_count(configuration)
-        subgroup_n_count = get_subgroup_n_count(configuration)
+        lowering_config = configuration.lowering_config
+        intrinsic = lowering_config.mma_kind
+        (
+            subgroup_m_count,
+            subgroup_n_count,
+        ) = lowering_config.subgroup_count_mn
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -316,8 +312,6 @@ def apply_params_broadcast_rhs_mmt(
         modified += apply_configuration(
             template,
             configuration,
-            get_batch_mmt_workgroup_sizes(configuration),
-            get_batch_mmt_reduction_sizes(configuration),
         )
 
         embeddable = indent(
@@ -345,8 +339,6 @@ def apply_params(
             apply_configuration(
                 template,
                 configuration,
-                get_contract_workgroup_sizes(configuration, self.tile_dims),
-                get_contract_reduction_sizes(configuration, self.tile_dims),
             ),
             "",
         )
@@ -359,9 +351,12 @@ def get_transform_function_batch_mmt(
         functionName: str,
         configuration: Configuration,
     ) -> str:
-        intrinsic = get_intrinsic(configuration)
-        subgroup_m_count = get_subgroup_m_count(configuration)
-        subgroup_n_count = get_subgroup_n_count(configuration)
+        lowering_config = configuration.lowering_config
+        intrinsic = lowering_config.mma_kind
+        (
+            subgroup_m_count,
+            subgroup_n_count,
+        ) = lowering_config.subgroup_count_mn
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -403,8 +398,6 @@ def apply_params(
         modified += apply_configuration(
             template,
             configuration,
-            get_batch_mmt_workgroup_sizes(configuration),
-            get_batch_mmt_reduction_sizes(configuration),
         )
 
         embeddable = indent(
@@ -428,9 +421,12 @@ def get_transform_function_batch_matmul(
         input1 = f"tensor<{problem_size.rhs_type}>"
         output = f"tensor<{problem_size.res_type}>"
 
-        intrinsic = get_intrinsic(configuration)
-        subgroup_m_count = get_subgroup_m_count(configuration)
-        subgroup_n_count = get_subgroup_n_count(configuration)
+        lowering_config = configuration.lowering_config
+        intrinsic = lowering_config.mma_kind
+        (
+            subgroup_m_count,
+            subgroup_n_count,
+        ) = lowering_config.subgroup_count_mn
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -476,8 +472,6 @@ def apply_params(
         modified += apply_configuration(
             template,
             configuration,
-            get_contract_workgroup_sizes(configuration, self.tile_dims),
-            get_contract_reduction_sizes(configuration, self.tile_dims),
         )
 
         embeddable = indent(
diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py
index 11de8a900..45da323c5 100644
--- a/tuner/tuner/candidate_gen_test.py
+++ b/tuner/tuner/candidate_gen_test.py
@@ -106,15 +106,15 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
         'gpu_pipeline_options = #iree_gpu.pipeline_options, {llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}',
= {"amdgpu-waves-per-eu" = "4"}', ] - n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640 + n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 16 mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, mma_kind=mma_attr, - workgroup=[464, 320, 0], - reduction=[0, 0, 16], + workgroup=[n, oh, ow, oc, fh, fw, 0], + reduction=[0, 0, 0, 0, 0, 0, ic], subgroup_m_count=1, subgroup_n_count=4, ) @@ -155,7 +155,7 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" in modified ) - assert "workgroup = [1, 1, 464, 320, 1, 1, 0]" in modified + assert "workgroup = [2, 64, 64, 640, 3, 3, 0]" in modified assert "reduction = [0, 0, 0, 0, 0, 0, 16]" in modified assert ( "gpu_pipeline_options = #iree_gpu.pipeline_options>" @@ -186,8 +186,8 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, mma_kind=mma_attr, - workgroup=[480, 384, 0], - reduction=[0, 0, 32], + workgroup=[1, 480, 384, 0], + reduction=[0, 0, 0, 32], subgroup_m_count=1, subgroup_n_count=4, ) @@ -241,8 +241,8 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, mma_kind=mma_attr, - workgroup=[416, 320, 0], - reduction=[0, 0, 128], + workgroup=[1, 416, 320, 0], + reduction=[0, 0, 0, 128], subgroup_m_count=2, subgroup_n_count=2, ) @@ -299,8 +299,8 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, mma_kind=mma_attr, - workgroup=[128, 64, 0], - reduction=[0, 0, 128], + workgroup=[1, 128, 64, 0], + reduction=[0, 0, 0, 128], subgroup_m_count=2, subgroup_n_count=2, ) @@ -355,8 +355,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, mma_kind=mma_attr, - workgroup=[128, 64, 0], - reduction=[0, 0, 128], + workgroup=[1, 128, 64, 0], + reduction=[0, 0, 0, 128], subgroup_m_count=2, subgroup_n_count=2, ) @@ -408,8 +408,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: "%config = transform.param.constant #iree_codegen.compilation_info<" in embeddable ) - assert "workgroup = [128, 64, 0]" in embeddable - assert "reduction = [0, 0, 128]" in embeddable + assert "workgroup = [1, 128, 64, 0]" in embeddable + assert "reduction = [0, 0, 0, 128]" in embeddable assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable @@ -435,8 +435,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, mma_kind=mma_attr, - workgroup=[128, 64, 0], - reduction=[0, 0, 128], + workgroup=[1, 128, 64, 0], + reduction=[0, 0, 0, 128], subgroup_m_count=2, subgroup_n_count=2, ) @@ -492,8 +492,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: "%config = transform.param.constant #iree_codegen.compilation_info<" in embeddable ) - assert "workgroup = [128, 64, 0]" in embeddable - assert "reduction = [0, 0, 128]" in embeddable + assert "workgroup = [1, 128, 64, 0]" in embeddable + assert "reduction = [0, 0, 0, 128]" in embeddable assert 'llvm_func_attrs = 
{"amdgpu-waves-per-eu" = "4"}' in embeddable assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index 702008f5e..0a2b03fd1 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -119,40 +119,6 @@ class Configuration: waves_per_eu: int -def get_intrinsic(config: Configuration) -> Optional[iree_gpu.MMAAttr]: - if "mma_kind" in config.lowering_config.attributes: - return config.lowering_config.attributes["mma_kind"] - return None - - -def get_workgroup_tile_sizes(config: Configuration) -> list[int]: - if "workgroup" in config.lowering_config.attributes: - workgroup_attrs = config.lowering_config.attributes["workgroup"] - return [attr.value for attr in workgroup_attrs] - return [] - - -def get_reduction_tile_sizes(config: Configuration) -> list[int]: - if "reduction" in config.lowering_config.attributes: - reduction_attrs = config.lowering_config.attributes["reduction"] - return [attr.value for attr in reduction_attrs] - return [] - - -def get_subgroup_m_count(config: Configuration) -> Optional[int]: - if "subgroup_m_count" in config.lowering_config.attributes: - attr = config.lowering_config.attributes["subgroup_m_count"] - return attr.value - return None - - -def get_subgroup_n_count(config: Configuration) -> Optional[int]: - if "subgroup_n_count" in config.lowering_config.attributes: - attr = config.lowering_config.attributes["subgroup_n_count"] - return attr.value - return None - - def get_lowering_config( tuner_ctx: TunerContext, **kwargs: Any, diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index f13aed3d7..6d76c216f 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -215,6 +215,5 @@ def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None: waves_per_eu=2, ) - assert common.get_intrinsic(config) is None - assert common.get_subgroup_m_count(config) == 1 - assert common.get_subgroup_n_count(config) == 1 + assert config.lowering_config.mma_kind is None + assert config.lowering_config.subgroup_count_mn == (1, 1) diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index ad63ba815..cc63c89a3 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -20,18 +20,10 @@ def parse_tensor_type(tensor_type: str) -> ShapedType: return ShapedType(shaped_ty.shape, shaped_ty.element_type) -def get_mmt_workgroup_sizes(configuration: Configuration): - return get_workgroup_tile_sizes(configuration) - - -def get_mmt_reduction_sizes(configuration: Configuration): - return get_reduction_tile_sizes(configuration) - - def get_contract_workgroup_sizes( configuration: Configuration, tile_dims: str ) -> list[int]: - m, n, _k = get_workgroup_tile_sizes(configuration) + m, n, _k = configuration.lowering_config.workgroup_tile_sizes workgroup_size = [1] * len(tile_dims) for idx, dim in enumerate(tile_dims): @@ -48,7 +40,7 @@ def get_contract_workgroup_sizes( def get_contract_reduction_sizes( configuration: Configuration, tile_dims: str ) -> list[int]: - _m, _n, k = get_reduction_tile_sizes(configuration) + _m, _n, k = configuration.lowering_config.reduction_tile_sizes reduction_size = [0] * len(tile_dims) for idx, dim in enumerate(tile_dims): if dim == "k": @@ -57,14 +49,6 @@ def get_contract_reduction_sizes( return reduction_size -def get_batch_mmt_workgroup_sizes(configuration: Configuration) -> list[int]: - return [1] + get_workgroup_tile_sizes(configuration) - - -def get_batch_mmt_reduction_sizes(configuration: 
-    return [0] + get_reduction_tile_sizes(configuration)
-
-
 class MlirRegex(Enum):
     ssa_value = r"%[a-zA-Z0-9-_]+"
     tensor_type = r"tensor<([^>]+)>"
@@ -164,22 +148,6 @@ class ConvParser(DispatchParser):
     def supports(self, op_name: str) -> bool:
         return "conv_2d_nhwc_hwcf" in op_name
 
-    def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]:
-        batch = 1
-        fh = 1
-        fw = 1
-
-        oh = 1
-
-        ow, oc, _ic = get_workgroup_tile_sizes(configuration)
-
-        return [batch, oh, ow, oc, fh, fw, 0]
-
-    def get_conv_reduction_sizes(self, configuration: Configuration) -> list[int]:
-        _ow, _oc, ic = get_reduction_tile_sizes(configuration)
-
-        return [0, 0, 0, 0, 0, 0, ic]
-
     def get_shapes(self, template: list[str]) -> ProblemSize:
         for line in template:
             if "linalg.conv_2d_nhwc_hwcf" not in line:
diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py
index 650540c63..db8c4a7da 100644
--- a/tuner/tuner/dispatch_parser_test.py
+++ b/tuner/tuner/dispatch_parser_test.py
@@ -57,8 +57,9 @@ def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None:
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
         waves_per_eu=0,
     )
-    assert dispatch_parser.get_mmt_workgroup_sizes(config) == [128, 320, 0]
-    assert dispatch_parser.get_mmt_reduction_sizes(config) == [0, 0, 32]
+    lowering_config = config.lowering_config
+    assert lowering_config.workgroup_tile_sizes == [128, 320, 0]
+    assert lowering_config.reduction_tile_sizes == [0, 0, 32]
 
 
 def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:
@@ -67,8 +68,8 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:
     lowering_config = common.get_lowering_config(
         tuner_ctx=tuner_ctx,
         mma_kind=mma_attr,
-        workgroup=[464, 320, 0],
-        reduction=[0, 0, 16],
+        workgroup=[1, 1, 464, 320, 1, 1, 0],
+        reduction=[0, 0, 0, 0, 0, 0, 16],
         subgroup_m_count=1,
         subgroup_n_count=4,
     )
@@ -79,24 +80,8 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
         waves_per_eu=1,
     )
-    assert dispatch_parser.ConvParser().get_conv_workgroup_sizes(config) == [
-        1,
-        1,
-        464,
-        320,
-        1,
-        1,
-        0,
-    ]
-    assert dispatch_parser.ConvParser().get_conv_reduction_sizes(config) == [
-        0,
-        0,
-        0,
-        0,
-        0,
-        0,
-        16,
-    ]
+    assert config.lowering_config.workgroup_tile_sizes == [1, 1, 464, 320, 1, 1, 0]
+    assert config.lowering_config.reduction_tile_sizes == [0, 0, 0, 0, 0, 0, 16]
 
 
 def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: