Skip to content

Commit

Permalink
[tuner]: use property function from iree lowering config python bindi…
Browse files Browse the repository at this point in the history
…ng (#662)

After landing iree-org/iree#19376, all helper
functions related to lowering configuration can be removed. Instead, we
can directly utilize property functions from the LoweringConfig Python
bindings.

This PR is still relevant to the task in
#453: use IREE bindings for
compilation info (incl., lowering_config and translation_info).

---------

Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
  • Loading branch information
bangtianliu authored Dec 9, 2024
1 parent 217690e commit bcc4ad5
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 156 deletions.
84 changes: 39 additions & 45 deletions tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,19 @@

tune_logger = logging.getLogger("tune")

# TODO: remove the argument 'workgroup_sizes' and 'reduction_sizes'.

def apply_configuration(
template: list[str],
configuration: Configuration,
workgroup_sizes: list[int],
reduction_sizes: list[int],
) -> str:
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn
workgroup_sizes = lowering_config.workgroup_tile_sizes
reduction_sizes = lowering_config.reduction_tile_sizes
tune_logger.info(f"Applying: {configuration}")
expr0 = re.compile(
r"<intrinsic = #iree_gpu\.mma_layout<(.+)>, subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
Expand Down Expand Up @@ -125,9 +128,12 @@ class MmtTuner(DispatchTuner, MmtParser):
def get_transform_function_mmt(
self, problem_size: ProblemSize, functionName: str, configuration: Configuration
) -> str:
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand Down Expand Up @@ -167,8 +173,6 @@ def apply_params(
modified += apply_configuration(
template,
configuration,
get_mmt_workgroup_sizes(configuration),
get_mmt_reduction_sizes(configuration),
)
embeddable = indent(
self.get_transform_function_mmt(problem_size, f"match_op", configuration),
Expand All @@ -193,15 +197,12 @@ def get_transform_function_conv(
filter = f"tensor<{problem_size.rhs_type}>"
output = f"tensor<{dynamic_batch_output_ty}>"

workgroup_sizes = ", ".join(
map(str, self.get_conv_workgroup_sizes(configuration))
)
reduction_sizes = ", ".join(
map(str, self.get_conv_reduction_sizes(configuration))
)
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand Down Expand Up @@ -246,8 +247,6 @@ def apply_params(
modified += apply_configuration(
template,
configuration,
self.get_conv_workgroup_sizes(configuration),
self.get_conv_reduction_sizes(configuration),
)
embeddable = indent(
self.get_transform_function_conv(problem_size, f"match_op", configuration),
Expand All @@ -263,15 +262,12 @@ def get_transform_function_broadcast_rhs_mmt(
functionName: str,
configuration: Configuration,
) -> str:
workgroup_sizes = ", ".join(
map(str, get_batch_mmt_workgroup_sizes(configuration))
)
reduction_sizes = ", ".join(
map(str, get_batch_mmt_reduction_sizes(configuration))
)
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand Down Expand Up @@ -316,8 +312,6 @@ def apply_params_broadcast_rhs_mmt(
modified += apply_configuration(
template,
configuration,
get_batch_mmt_workgroup_sizes(configuration),
get_batch_mmt_reduction_sizes(configuration),
)

embeddable = indent(
Expand Down Expand Up @@ -345,8 +339,6 @@ def apply_params(
apply_configuration(
template,
configuration,
get_contract_workgroup_sizes(configuration, self.tile_dims),
get_contract_reduction_sizes(configuration, self.tile_dims),
),
"",
)
Expand All @@ -359,9 +351,12 @@ def get_transform_function_batch_mmt(
functionName: str,
configuration: Configuration,
) -> str:
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand Down Expand Up @@ -403,8 +398,6 @@ def apply_params(
modified += apply_configuration(
template,
configuration,
get_batch_mmt_workgroup_sizes(configuration),
get_batch_mmt_reduction_sizes(configuration),
)

embeddable = indent(
Expand All @@ -428,9 +421,12 @@ def get_transform_function_batch_matmul(
input1 = f"tensor<{problem_size.rhs_type}>"
output = f"tensor<{problem_size.res_type}>"

intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand Down Expand Up @@ -476,8 +472,6 @@ def apply_params(
modified += apply_configuration(
template,
configuration,
get_contract_workgroup_sizes(configuration, self.tile_dims),
get_contract_reduction_sizes(configuration, self.tile_dims),
)

embeddable = indent(
Expand Down
36 changes: 18 additions & 18 deletions tuner/tuner/candidate_gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,15 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
'gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>, {llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}',
]

n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640
n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 16

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_kind=mma_attr,
workgroup=[464, 320, 0],
reduction=[0, 0, 16],
workgroup=[n, oh, ow, oc, fh, fw, 0],
reduction=[0, 0, 0, 0, 0, 0, ic],
subgroup_m_count=1,
subgroup_n_count=4,
)
Expand Down Expand Up @@ -155,7 +155,7 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
"LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64"
in modified
)
assert "workgroup = [1, 1, 464, 320, 1, 1, 0]" in modified
assert "workgroup = [2, 64, 64, 640, 3, 3, 0]" in modified
assert "reduction = [0, 0, 0, 0, 0, 0, 16]" in modified
assert (
"gpu_pipeline_options = #iree_gpu.pipeline_options<reorder_workgroups_strategy = <Transpose>>"
Expand Down Expand Up @@ -186,8 +186,8 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_kind=mma_attr,
workgroup=[480, 384, 0],
reduction=[0, 0, 32],
workgroup=[1, 480, 384, 0],
reduction=[0, 0, 0, 32],
subgroup_m_count=1,
subgroup_n_count=4,
)
Expand Down Expand Up @@ -241,8 +241,8 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_kind=mma_attr,
workgroup=[416, 320, 0],
reduction=[0, 0, 128],
workgroup=[1, 416, 320, 0],
reduction=[0, 0, 0, 128],
subgroup_m_count=2,
subgroup_n_count=2,
)
Expand Down Expand Up @@ -299,8 +299,8 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_kind=mma_attr,
workgroup=[128, 64, 0],
reduction=[0, 0, 128],
workgroup=[1, 128, 64, 0],
reduction=[0, 0, 0, 128],
subgroup_m_count=2,
subgroup_n_count=2,
)
Expand Down Expand Up @@ -355,8 +355,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_kind=mma_attr,
workgroup=[128, 64, 0],
reduction=[0, 0, 128],
workgroup=[1, 128, 64, 0],
reduction=[0, 0, 0, 128],
subgroup_m_count=2,
subgroup_n_count=2,
)
Expand Down Expand Up @@ -408,8 +408,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
"%config = transform.param.constant #iree_codegen.compilation_info<"
in embeddable
)
assert "workgroup = [128, 64, 0]" in embeddable
assert "reduction = [0, 0, 128]" in embeddable
assert "workgroup = [1, 128, 64, 0]" in embeddable
assert "reduction = [0, 0, 0, 128]" in embeddable
assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable
assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable

Expand All @@ -435,8 +435,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_kind=mma_attr,
workgroup=[128, 64, 0],
reduction=[0, 0, 128],
workgroup=[1, 128, 64, 0],
reduction=[0, 0, 0, 128],
subgroup_m_count=2,
subgroup_n_count=2,
)
Expand Down Expand Up @@ -492,8 +492,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
"%config = transform.param.constant #iree_codegen.compilation_info<"
in embeddable
)
assert "workgroup = [128, 64, 0]" in embeddable
assert "reduction = [0, 0, 128]" in embeddable
assert "workgroup = [1, 128, 64, 0]" in embeddable
assert "reduction = [0, 0, 0, 128]" in embeddable
assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable
assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable

Expand Down
34 changes: 0 additions & 34 deletions tuner/tuner/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,40 +119,6 @@ class Configuration:
waves_per_eu: int


def get_intrinsic(config: Configuration) -> Optional[iree_gpu.MMAAttr]:
if "mma_kind" in config.lowering_config.attributes:
return config.lowering_config.attributes["mma_kind"]
return None


def get_workgroup_tile_sizes(config: Configuration) -> list[int]:
if "workgroup" in config.lowering_config.attributes:
workgroup_attrs = config.lowering_config.attributes["workgroup"]
return [attr.value for attr in workgroup_attrs]
return []


def get_reduction_tile_sizes(config: Configuration) -> list[int]:
if "reduction" in config.lowering_config.attributes:
reduction_attrs = config.lowering_config.attributes["reduction"]
return [attr.value for attr in reduction_attrs]
return []


def get_subgroup_m_count(config: Configuration) -> Optional[int]:
if "subgroup_m_count" in config.lowering_config.attributes:
attr = config.lowering_config.attributes["subgroup_m_count"]
return attr.value
return None


def get_subgroup_n_count(config: Configuration) -> Optional[int]:
if "subgroup_n_count" in config.lowering_config.attributes:
attr = config.lowering_config.attributes["subgroup_n_count"]
return attr.value
return None


def get_lowering_config(
tuner_ctx: TunerContext,
**kwargs: Any,
Expand Down
5 changes: 2 additions & 3 deletions tuner/tuner/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,5 @@ def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None:
waves_per_eu=2,
)

assert common.get_intrinsic(config) is None
assert common.get_subgroup_m_count(config) == 1
assert common.get_subgroup_n_count(config) == 1
assert config.lowering_config.mma_kind is None
assert config.lowering_config.subgroup_count_mn == (1, 1)
36 changes: 2 additions & 34 deletions tuner/tuner/dispatch_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,10 @@ def parse_tensor_type(tensor_type: str) -> ShapedType:
return ShapedType(shaped_ty.shape, shaped_ty.element_type)


def get_mmt_workgroup_sizes(configuration: Configuration):
return get_workgroup_tile_sizes(configuration)


def get_mmt_reduction_sizes(configuration: Configuration):
return get_reduction_tile_sizes(configuration)


def get_contract_workgroup_sizes(
configuration: Configuration, tile_dims: str
) -> list[int]:
m, n, _k = get_workgroup_tile_sizes(configuration)
m, n, _k = configuration.lowering_config.workgroup_tile_sizes

workgroup_size = [1] * len(tile_dims)
for idx, dim in enumerate(tile_dims):
Expand All @@ -48,7 +40,7 @@ def get_contract_workgroup_sizes(
def get_contract_reduction_sizes(
configuration: Configuration, tile_dims: str
) -> list[int]:
_m, _n, k = get_reduction_tile_sizes(configuration)
_m, _n, k = configuration.lowering_config.reduction_tile_sizes
reduction_size = [0] * len(tile_dims)
for idx, dim in enumerate(tile_dims):
if dim == "k":
Expand All @@ -57,14 +49,6 @@ def get_contract_reduction_sizes(
return reduction_size


def get_batch_mmt_workgroup_sizes(configuration: Configuration) -> list[int]:
return [1] + get_workgroup_tile_sizes(configuration)


def get_batch_mmt_reduction_sizes(configuration: Configuration) -> list[int]:
return [0] + get_reduction_tile_sizes(configuration)


class MlirRegex(Enum):
ssa_value = r"%[a-zA-Z0-9-_]+"
tensor_type = r"tensor<([^>]+)>"
Expand Down Expand Up @@ -164,22 +148,6 @@ class ConvParser(DispatchParser):
def supports(self, op_name: str) -> bool:
return "conv_2d_nhwc_hwcf" in op_name

def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]:
batch = 1
fh = 1
fw = 1

oh = 1

ow, oc, _ic = get_workgroup_tile_sizes(configuration)

return [batch, oh, ow, oc, fh, fw, 0]

def get_conv_reduction_sizes(self, configuration: Configuration) -> list[int]:
_ow, _oc, ic = get_reduction_tile_sizes(configuration)

return [0, 0, 0, 0, 0, 0, ic]

def get_shapes(self, template: list[str]) -> ProblemSize:
for line in template:
if "linalg.conv_2d_nhwc_hwcf" not in line:
Expand Down
Loading

0 comments on commit bcc4ad5

Please sign in to comment.