diff --git a/sharktank/sharktank/tools/tuner/README.md b/sharktank/sharktank/tools/tuner/README.md
deleted file mode 100644
index 69821496e..000000000
--- a/sharktank/sharktank/tools/tuner/README.md
+++ /dev/null
@@ -1,67 +0,0 @@
-# IREE dispatch auto-tuning scripts
-`libtuner.py` is the core Python script that provides the fundamental functions for the tuning loop. It imports `candidate_gen.py` for candidate generation. To implement the full tuning loop, `libtuner.py` requires a separate Python script that uses the provided `TuningClient` API from `libtuner.py`.
-
-## Prerequisites
-[Optional] Using virtual environments:
-```shell
-cd tuning
-python -m venv .venv
-source .venv/bin/activate
-```
-Install Python dependencies:
-```shell
-pip install -r ./requirements-tuner.txt
-```
-Using IREE's Python bindings:
-  - Build with CMake:
-    ```shell
-    -DIREE_BUILD_PYTHON_BINDINGS=ON \
-    -DPython3_EXECUTABLE="$(which python)"
-    ```
-  - Set the environment:
-    ```shell
-    source ../iree-build/.env && export PYTHONPATH
-    ```
-For more information, refer to the [IREE documentation](https://iree.dev/building-from-source/getting-started/#python-bindings).
-
-### Overall flow
-
-1. Symlink all scripts and the MLIR/IRPA files into your build dir.
-   - Symlink `iree-build-dir/tools` inside `tuning`.
-   - Symlink the ML model MLIR and weights based on `unet.sh`.
-
-2. Copy the attention/matmul spec as `config.mlir` into the tuning dir.
-
-3. Temporarily comment out all the existing configs in `config.mlir`.
-   - Example:
-     ```mlir
-     // , @match_mmt_2048x10240x1280 -> @apply_op_config
-     // , @match_mmt_2048x1280x5120 -> @apply_op_config
-     // , @match_mmt_2048x1280x1280 -> @apply_op_config
-     ```
-
-4. Compile a baseline unet:
-```shell
-./unet.sh winograd unet.mlir -o unet_baseline.vmfb --iree-hal-dump-executable-files-to=dump-winograd
-```
-
-5. Find the matmul to tune and copy the `*_benchmark.mlir` file to the build dir:
-```shell
-cp dump-winograd/*_141_*benchmark.mlir ./141.mlir
-```
-
-6. Run the tuning script.
-   - Example:
-     ```shell
-     python punet_autotune.py 141.mlir --devices=hip://GPU-0,hip://GPU-4 --num-candidates=1024
-     ```
-
-7. Check the winning candidate in `result_summary.log`, then find and copy its transform spec.
-
-8. Paste the transform spec into `config.mlir` and uncomment it.
-
-9. Add the match function to the entry point in `config.mlir`.
-   - Example:
-     ```mlir
-     @match_something -> @apply_op_config
-     ```
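For orientation, the driver in step 6 (`punet_autotune.py`) is a thin script built on the `TuningClient` API that `libtuner.py` exposes. The sketch below is only illustrative: everything besides the `libtuner`/`TuningClient` names mentioned in this README is an assumption, and the real hook names and entry point should be taken from `libtuner.py` itself.

```python
# Hypothetical driver sketch; hook names and the main() entry point are
# assumptions, not the actual libtuner interface.
import libtuner


class PunetClient(libtuner.TuningClient):
    def get_dispatch_compile_command(self, candidate):  # assumed hook
        # Compile one generated candidate, e.g. via the symlinked tools/.
        ...

    def get_dispatch_benchmark_command(self, candidate):  # assumed hook
        # Benchmark the compiled candidate on one of the --devices.
        ...


if __name__ == "__main__":
    libtuner.main(PunetClient())  # assumed entry point
```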
diff --git a/sharktank/sharktank/tools/tuner/candidate_gen.py b/sharktank/sharktank/tools/tuner/candidate_gen.py
deleted file mode 100755
index 8a8315afb..000000000
--- a/sharktank/sharktank/tools/tuner/candidate_gen.py
+++ /dev/null
@@ -1,1408 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-# Given an input dispatch, this code modifies the hyperparameters
-# in the code and runs it.
-
-"""
-Generate candidates by tweaking op configuration for tuning.
-
-It can be invoked in two ways:
-    1. From another Python script, import and call `tune()`
-    2. Run this script directly from the command line
-
-Usage: ./candidate_gen.py 121.mlir -o "tuning/candidates" -l 1024 --lhs-dims=mk --rhs-dims=nk --tile-dims=mnk
-
-"""
-
-import argparse
-import logging
-import math
-import pickle
-import re
-import z3
-from dataclasses import asdict, dataclass
-from enum import Enum
-from os import mkdir, path, makedirs
-from typing import Callable, Optional
-from textwrap import indent
-from abc import ABC, abstractmethod
-
-import iree.compiler as ireec
-from iree.compiler import ir
-from iree.compiler.dialects import _linalg_ops_gen, _util_ops_gen
-
-
-tune_logger = logging.getLogger("tune")
-
-
-class DispatchKind(Enum):
-    conv = 1
-    mmt = 2
-    contraction = 3
-    batch_mmt = 4
-    batch_matmul = 5
-    broadcast_rhs_mmt = 6
-
-
-class ElementType(Enum):
-    i8 = 1
-    i32 = 2
-    f8 = 3
-    f16 = 4
-    f32 = 5
-
-    @property
-    def bitwidth(self) -> int:
-        match self:
-            case ElementType.i8 | ElementType.f8:
-                return 8
-            case ElementType.f16:
-                return 16
-            case ElementType.i32 | ElementType.f32:
-                return 32
-            case _:
-                assert False, "unhandled case"
-
-    def __str__(self) -> str:
-        return self.name
-
-
-@dataclass
-class ShapedType:
-    shape: list[int]
-    element_type: ElementType
-
-    def rank(self) -> int:
-        return len(self.shape)
-
-    @property
-    def bitwidth(self) -> int:
-        return self.element_type.bitwidth
-
-    def __str__(self) -> str:
-        dim_to_str = lambda dim: str(dim) if dim != -1 else "?"
-        return "x".join(map(dim_to_str, self.shape)) + "x" + str(self.element_type)
-
-
-@dataclass
-class MatmulSize:
-    M: int
-    N: int
-    K: int
-    B: int = 1
-
-
-@dataclass
-class ProblemSize:
-    matmul_size: MatmulSize
-    lhs_type: ShapedType
-    rhs_type: ShapedType
-    res_type: ShapedType
-    dispatch_kind: DispatchKind
-
-    @property
-    def MNK(self) -> tuple[int, int, int]:
-        return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K)
-
-
-@dataclass
-class MfmaIntrinsic:
-    input_type: ElementType
-    m: int
-    n: int
-    k: int
-    output_type: ElementType
-
-    def __str__(self) -> str:
-        input = str(self.input_type).upper()
-        output = str(self.output_type).upper()
-        return f"MFMA_{input}_{self.m}x{self.n}x{self.k}_{output}"
-
-    @staticmethod
-    def mfma_f16_16x16x16_f32():
-        return MfmaIntrinsic(ElementType.f16, 16, 16, 16, ElementType.f32)
-
-    @staticmethod
-    def mfma_f16_32x32x8_f32():
-        return MfmaIntrinsic(ElementType.f16, 32, 32, 8, ElementType.f32)
-
-    @staticmethod
-    def mfma_i8_16x16x32_i32():
-        return MfmaIntrinsic(ElementType.i8, 16, 16, 32, ElementType.i32)
-
-    @staticmethod
-    def mfma_i8_32x32x16_i32():
-        return MfmaIntrinsic(ElementType.i8, 32, 32, 16, ElementType.i32)
-
-    @staticmethod
-    def all():
-        return [
-            MfmaIntrinsic.mfma_f16_16x16x16_f32(),
-            MfmaIntrinsic.mfma_f16_32x32x8_f32(),
-            MfmaIntrinsic.mfma_i8_16x16x32_i32(),
-            MfmaIntrinsic.mfma_i8_32x32x16_i32(),
-        ]
-
-
-@dataclass
-class Configuration:
-    subgroup_size: int
-    workgroup_size: list[int]
-    intrinsic: MfmaIntrinsic
-    tile_sizes: list[int]
-    subgroup_m_count: int
-    subgroup_n_count: int
-    waves_per_eu: int
-
-
-class MlirRegex(str, Enum):
-    ssa_value = r"%[a-zA-Z0-9-_]+"
-    tensor_type = r"tensor<(([0-9]+x)+((f|i)[0-9]+))>"
-
-    @staticmethod
-    def dps_ins_two_args() -> str:
-        return rf"ins\({MlirRegex.ssa_value}, {MlirRegex.ssa_value} : (?P<LHS>{MlirRegex.tensor_type}), (?P<RHS>{MlirRegex.tensor_type})\)"
-
-    @staticmethod
-    def dps_outs_one_arg() -> str:
-        return rf"outs\({MlirRegex.ssa_value} : (?P<RES>{MlirRegex.tensor_type})\)"
-
-
-def read_input_mlir(filename: str) -> list[str]:
-    with open(filename, "r") as f:
-        return 
f.readlines() - - -def get_mmt_tile_sizes(configuration: Configuration): - return configuration.tile_sizes - - -@dataclass -class ConvDimInfo: - n: int - oh: int - ow: int - oc: int - fh: int - fw: int - ic: int - - @staticmethod - def from_rhs_res(rhs_shaped_type: ShapedType, res_shaped_type: ShapedType): - n, oh, ow, oc = res_shaped_type.shape - fh, fw, ic, _ = rhs_shaped_type.shape - return ConvDimInfo(n, oh, ow, oc, fh, fw, ic) - - @staticmethod - def from_problem_size(problem_size: ProblemSize): - return ConvDimInfo.from_rhs_res(problem_size.rhs_type, problem_size.res_type) - - -def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]: - m, n, k = configuration.tile_sizes - tile_size = [1] * len(tile_dims) - for idx, dim in enumerate(tile_dims): - if dim == "m": - tile_size[idx] = m - if dim == "n": - tile_size[idx] = n - if dim == "k": - tile_size[idx] = k - return tile_size - - -def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]: - return [1] + configuration.tile_sizes - - -def get_pipeline_config(configuration: Configuration) -> str: - extra_config = ", prefetch_shared_memory" - if configuration.waves_per_eu != 2: - extra_config += f', llvm_func_attrs = {{"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"}}' - return extra_config - - -def apply_configuration( - template: list[str], configuration: Configuration, tile_sizes: list[int] -) -> str: - tune_logger.info(f"Applying: {configuration}") - expr0 = re.compile( - r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>" - ) - expr1 = re.compile( - r"LLVMGPUVectorDistribute workgroup_size = \[.+\] subgroup_size = ([0-9]+)," - ) - expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]") - expr3 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") - repl0 = f", subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>" - repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},' - repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]' - repl3 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"' - - new_mlir = "" - for line in template: - if "intrinsic =" in line: - line = re.sub(expr0, repl0, line) - if "LLVMGPUVectorDistribute " in line: - line = re.sub(expr1, repl1, line) - if "tile_sizes" in line: - line = re.sub(expr2, repl2, line) - if "amdgpu-waves-per-eu" in line: - line = re.sub(expr3, repl3, line) - new_mlir += line - - return new_mlir - - -def parse_tensor_type(tensor_type: str) -> ShapedType: - shape_match = re.search(MlirRegex.tensor_type, tensor_type) - assert shape_match - - shape_str = shape_match.group(1) - dims_and_elem = shape_str.split("x") - dims = [int(x) for x in dims_and_elem[:-1]] - elem = dims_and_elem[-1] - str_to_elem_ty = {x.name: x for x in ElementType} - return ShapedType(dims, str_to_elem_ty[elem]) - - -def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]: - def is_compatible(intrinsic: MfmaIntrinsic) -> bool: - if problem_size.res_type.element_type != intrinsic.output_type: - return False - if problem_size.dispatch_kind != DispatchKind.batch_matmul: - if problem_size.lhs_type.element_type != intrinsic.input_type: - return False - if problem_size.rhs_type.element_type != intrinsic.input_type: - return False - return True - - return list(filter(is_compatible, MfmaIntrinsic.all())) - - -def get_mfma_intrinsic_constraints( - problem_size: ProblemSize, 
- intrinsic_m: z3.ArithRef, - intrinsic_n: z3.ArithRef, - intrinsic_k: z3.ArithRef, -) -> z3.BoolRef: - compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size) - assert len(compatible_intrinsics) > 0, "No compatible intrinsics found" - return z3.Or( - *( - z3.And(intrinsic_m == mfma.m, intrinsic_n == mfma.n, intrinsic_k == mfma.k) - for mfma in compatible_intrinsics - ) - ) - - -def get_dispatch_constraints( - problem_size: ProblemSize, - tile_m: z3.ArithRef, - tile_n: z3.ArithRef, - tile_k: z3.ArithRef, -) -> list[z3.BoolRef]: - if problem_size.dispatch_kind != DispatchKind.conv: - return [] - - dim_info = ConvDimInfo.from_problem_size(problem_size) - conv_constraints = [] - # WARNING: This sometimes makes the constraints UNSAT for some reason. - conv_constraints += [tile_m <= dim_info.ow] - conv_constraints += [tile_n <= dim_info.oc] - conv_constraints += [tile_k <= dim_info.ic] - return conv_constraints - - -def calculate_shared_memory_usage_in_bytes( - problem_size: ProblemSize, - m: int | z3.ArithRef, - n: int | z3.ArithRef, - k: int | z3.ArithRef, -) -> int | z3.ArithRef: - lhs_memory = m * k * (problem_size.lhs_type.bitwidth // 8) - rhs_memory = k * n * (problem_size.rhs_type.bitwidth // 8) - return lhs_memory + rhs_memory - - -def generate_constraints( - problem_size: ProblemSize, - tile_sizes, - num_subgroups, - subgroup_size, - intrinsic_size, - workgroup_size, - subgroup_m_count, - subgroup_n_count, - waves_per_eu, -): - M, N, K = ( - problem_size.matmul_size.M, - problem_size.matmul_size.N, - problem_size.matmul_size.K, - ) - m, n, k = tile_sizes - intrinsic_mn, intrinsic_k = intrinsic_size - wg_x, wg_y, wg_z = workgroup_size - wg_threads = z3.Int("wg_threads") - constraints = [wg_threads == wg_x * wg_y * wg_z] - constraints += [subgroup_size == 64, wg_threads <= 1024] - constraints += [ - get_mfma_intrinsic_constraints( - problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k - ) - ] - subgroup_k_count = 1 - constraints += [ - m >= intrinsic_mn, - m <= 512, - m <= M, - ] - constraints += [n >= intrinsic_mn, n <= 512, n <= N, N % n == 0] - constraints += [k >= intrinsic_k, k <= 512, k <= K, K % k == 0] - for x in (subgroup_m_count, subgroup_n_count): - constraints += [x >= 1, x <= 32] - - subgroup_m_tile_count = z3.Int("sg_m_tcnt") - subgroup_n_tile_count = z3.Int("sg_n_tcnt") - subgroup_k_tile_count = z3.Int("sg_k_tcnt") - for x in (subgroup_m_tile_count, subgroup_n_tile_count, subgroup_k_tile_count): - constraints += [x >= 1, x <= 32] - - constraints += [m == subgroup_m_count * subgroup_m_tile_count * intrinsic_mn] - constraints += [n == subgroup_n_count * subgroup_n_tile_count * intrinsic_mn] - constraints += [k == subgroup_k_count * subgroup_k_tile_count * intrinsic_k] - constraints += [wg_x == subgroup_size * subgroup_n_count] - constraints += [wg_y == subgroup_m_count] - constraints += [wg_z == subgroup_k_count] - constraints += [z3.Or(wg_x <= n, wg_x <= m)] - constraints += [k % intrinsic_mn == 0] - constraints += [(k * n) % wg_threads == 0] - constraints += [(k * m) % wg_threads == 0] - subgroups = subgroup_m_count * subgroup_n_count - if num_subgroups > 0: - constraints += [subgroups == num_subgroups] - else: - constraints += [subgroups >= 1, subgroups <= 10] - - constraints += [waves_per_eu == 2] - # constraints += [z3.Or(waves_per_eu == 2, waves_per_eu == 3, waves_per_eu == 4)] - - shared_memory = calculate_shared_memory_usage_in_bytes(problem_size, m, n, k) - constraints += [shared_memory <= 65536] - - constraints += 
get_dispatch_constraints(problem_size, m, n, k)
-
-    return constraints
-
-
-def generate_solutions(problem_size: ProblemSize, num_subgroups: int):
-    M, N, K = problem_size.MNK
-    tune_logger.info(f"{M},{N},{K}")
-    m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k")
-    subgroup_size = z3.Int("subgroup_size")
-    intrinsic_mn = z3.Int("intrinsic_mn")
-    intrinsic_k = z3.Int("intrinsic_k")
-    wg_x, wg_y, wg_z = z3.Int("wg_x"), z3.Int("wg_y"), z3.Int("wg_z")
-    sg_m_cnt = z3.Int("sg_m_cnt")
-    sg_n_cnt = z3.Int("sg_n_cnt")
-    waves_per_eu = z3.Int("waves_per_eu")
-    all_vars = [
-        m,
-        n,
-        k,
-        subgroup_size,
-        intrinsic_mn,
-        intrinsic_k,
-        wg_x,
-        wg_y,
-        wg_z,
-        sg_m_cnt,
-        sg_n_cnt,
-        waves_per_eu,
-    ]
-
-    solver = z3.Solver()
-    constraints = generate_constraints(
-        problem_size,
-        [m, n, k],
-        num_subgroups,
-        subgroup_size,
-        [intrinsic_mn, intrinsic_k],
-        [wg_x, wg_y, wg_z],
-        sg_m_cnt,
-        sg_n_cnt,
-        waves_per_eu,
-    )
-    solver.add(z3.simplify(z3.And(constraints)))
-    tune_logger.debug(f"Initial constraints: {solver}")
-    i = 0
-    while solver.check() == z3.sat:
-        model = solver.model()
-        lookup = lambda var: model[var].as_long()
-
-        config = Configuration(
-            lookup(subgroup_size),
-            [lookup(wg_x), lookup(wg_y), lookup(wg_z)],
-            MfmaIntrinsic(
-                problem_size.lhs_type.element_type,
-                lookup(intrinsic_mn),
-                lookup(intrinsic_mn),
-                lookup(intrinsic_k),
-                problem_size.res_type.element_type,
-            ),
-            [lookup(m), lookup(n), lookup(k)],
-            lookup(sg_m_cnt),
-            lookup(sg_n_cnt),
-            lookup(waves_per_eu),
-        )
-        solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars)))))
-        i += 1
-        yield config
-
-
-def get_default_output_dir() -> str:
-    from datetime import datetime
-
-    return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M")
-
-
-def parse_mlir(mlir_text: str) -> ir.Module:
-    mlir_module = None
-    with ireec.ir.Context() as context:
-        try:
-            mlir_module = ireec.ir.Module.parse(mlir_text)
-            tune_logger.info("MLIR parsing successful!")
-        except ireec.ir.MLIRError as e:
-            tune_logger.error(f"Error parsing MLIR: {e}")
-            raise RuntimeError(f"Error parsing MLIR: {e}")
-
-    return mlir_module
-
-
-@dataclass
-class MLIRTransformation:
-    """Transformation of MLIR context"""
-
-    template: str
-    modified: str
-    embeddable: str
-
-
-class DispatchTuner(ABC):
-    @abstractmethod
-    def supports(self, op_name: str) -> bool:
-        """Check if the tuner can handle the type of operation represented by the input string."""
-        pass
-
-    @abstractmethod
-    def get_shapes(self, template: list[str]) -> ProblemSize:
-        """Extract the problem size of the operation."""
-        pass
-
-    @abstractmethod
-    def apply_params(
-        self,
-        problem_size: ProblemSize,
-        template: list[str],
-        configuration: Configuration,
-    ) -> MLIRTransformation:
-        """Apply parameter transformations to the operation."""
-        pass
-
-
-@dataclass
-class OpWalkResult:
-    was_interrupted: bool = False
-    dispatch_tuner: Optional[DispatchTuner] = None
-
-
-class DispatchTunerRegistry:
-    def __init__(self):
-        self.registry = set()
-
-    def register(self, dispatch_tuners: list[DispatchTuner]) -> None:
-        for dispatch_tuner in dispatch_tuners:
-            self.registry.add(dispatch_tuner)
-
-    def validate_translation(self, attrs: list[ir.NamedAttribute]) -> bool:
-        for attr in attrs:
-            if (attr.name == "translation_info") and (
-                "LLVMGPUVectorDistribute" in str(attr.attr)
-            ):
-                return True
-        assert False, "Translation info not supported"
-
-    def find_handler(self, op_name: str) -> DispatchTuner:
-        for dispatch_tuner in self.registry:
-            if dispatch_tuner.supports(op_name):
-                return dispatch_tuner
-        assert False, "Dispatch kind not supported"
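Because `generate_solutions` adds a blocking clause (`z3.Not(...)` over all variables) after every model, it lazily enumerates distinct configurations, and callers cap the enumeration with plain iteration. A self-contained usage sketch, with an arbitrary mmt shape chosen for illustration:

```python
import itertools

problem = ProblemSize(
    MatmulSize(M=2048, N=1280, K=1280),
    lhs_type=ShapedType([2048, 1280], ElementType.f16),
    rhs_type=ShapedType([1280, 1280], ElementType.f16),
    res_type=ShapedType([2048, 1280], ElementType.f32),
    dispatch_kind=DispatchKind.mmt,
)
# Take only the first five satisfying configurations from the solver.
for config in itertools.islice(generate_solutions(problem, 4), 5):
    print(config.tile_sizes, config.workgroup_size, config.waves_per_eu)
```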
-
-
-class MmtTuner(DispatchTuner):
-    def supports(self, op_name: str) -> bool:
-        return "matmul_transpose_b" in op_name
-
-    def get_shapes(self, template: list[str]) -> ProblemSize:
-        mmt_re = None
-        dps = None
-        for line in template:
-            if "linalg.generic" not in line:
-                continue
-            if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line:
-                continue
-            # ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>)
-            mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
-            dps = re.search(mmt_re, line)
-            if dps is None:
-                continue
-
-            lhs_tensor_type = dps.group("LHS")
-            rhs_tensor_type = dps.group("RHS")
-            lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
-            assert lhs_shaped_type.rank() == 2
-            lhs_M, lhs_K = lhs_shaped_type.shape
-
-            rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
-            assert rhs_shaped_type.rank() == 2
-            rhs_N, rhs_K = rhs_shaped_type.shape
-
-            assert lhs_shaped_type.element_type == rhs_shaped_type.element_type
-            assert lhs_K == rhs_K
-
-            res_tensor_type = dps.group("RES")
-            res_shaped_type = parse_tensor_type(res_tensor_type)
-            assert res_shaped_type.rank() == 2
-            res_M, res_N = res_shaped_type.shape
-
-            assert lhs_M == res_M
-            assert rhs_N == res_N
-
-            matmul_size = MatmulSize(
-                lhs_shaped_type.shape[0],
-                rhs_shaped_type.shape[0],
-                lhs_shaped_type.shape[1],
-            )
-            return ProblemSize(
-                matmul_size,
-                lhs_type=lhs_shaped_type,
-                rhs_type=rhs_shaped_type,
-                res_type=res_shaped_type,
-                dispatch_kind=DispatchKind.mmt,
-            )
-        assert mmt_re
-        assert dps, f"'{mmt_re}' not found in given context"
-
-    def get_transform_function_mmt(
-        self, problem_size: ProblemSize, functionName: str, configuration: Configuration
-    ) -> str:
-        tile_sizes = ", ".join(map(str, get_mmt_tile_sizes(configuration)))
-
-        wg_x, wg_y, wg_z = configuration.workgroup_size
-        extra_config = get_pipeline_config(configuration)
-
-        return f"""
-    transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
-        %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
-        %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
-        %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
-        transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
-        transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
-        %config = transform.param.constant #iree_codegen.compilation_info<
-            lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
-            translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
-            workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
-            {{mma_schedule = #iree_gpu.mma_schedule<
-                intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
-                subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
-            {extra_config}}}>
-        > -> !transform.any_param
-        transform.yield %matmul, %config : !transform.any_op, !transform.any_param
-    }}
-    """
-
-    def apply_params(
-        self,
-        problem_size: ProblemSize,
-        template: list[str],
-        configuration: Configuration,
-    ) -> MLIRTransformation:
-        M, N, K = problem_size.MNK
-        modified = indent(
-            self.get_transform_function_mmt(
-                problem_size, f"match_mmt_{M}x{N}x{K}", configuration
-            ),
-            "// ",
-        )
-        modified += apply_configuration(
-            template, configuration, get_mmt_tile_sizes(configuration)
-        )
-        embeddable = indent(
-            self.get_transform_function_mmt(problem_size, f"match_op", configuration),
-            "  ",
-        )
-        return MLIRTransformation(template, modified, embeddable)
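The two strings produced by `apply_params` feed different steps of the README flow: `modified` is a complete candidate dispatch with the matcher kept as a `// `-commented header, while `embeddable` is the indented spec meant to be pasted into `config.mlir` (step 8). A sketch of how the pieces chain together, mirroring what `tune()` below automates; it assumes `141.mlir` (from step 5 of the README) contains a matching `matmul_transpose_b` dispatch:

```python
tuner = MmtTuner()
template = read_input_mlir("141.mlir")         # dispatch copied out in step 5
problem = tuner.get_shapes(template)
config = next(generate_solutions(problem, 4))  # first feasible configuration
tf = tuner.apply_params(problem, template, config)
with open("1.mlir", "w") as f:
    f.write(tf.modified)    # standalone candidate for compilation
with open("1_config.mlir", "w") as f:
    f.write(tf.embeddable)  # transform spec to embed in config.mlir
```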
-
-
-class ConvTuner(DispatchTuner):
-    def supports(self, op_name: str) -> bool:
-        return "conv_2d_nhwc_hwcf" in op_name
-
-    def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]:
-        m, n, k = configuration.tile_sizes
-        batch = 1
-        fh = 1
-        fw = 1
-
-        oh = 1
-
-        oc = n
-        ow = m
-        ic = k
-        return [batch, oh, ow, oc, fh, fw, ic]
-
-    def get_shapes(self, template: list[str]) -> ProblemSize:
-        for line in template:
-            if "linalg.conv_2d_nhwc_hwcf" not in line:
-                continue
-
-            # ins(%19, %20 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs (%27 : tensor<2x32x32x1280xf32>)
-            conv_re = (
-                rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
-            )
-            dps = re.search(conv_re, line)
-            if dps is None:
-                continue
-
-            lhs_tensor_type = dps.group("LHS")
-            rhs_tensor_type = dps.group("RHS")
-            lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
-            assert lhs_shaped_type.rank() == 4
-
-            rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
-            assert rhs_shaped_type.rank() == 4
-
-            res_tensor_type = dps.group("RES")
-            res_shaped_type = parse_tensor_type(res_tensor_type)
-            assert res_shaped_type.rank() == 4
-
-            # int64_t n = outputShape[0];
-            # int64_t oh = outputShape[1];
-            # int64_t ow = outputShape[2];
-            # int64_t oc = outputShape[3];
-            # int64_t fh = filterShape[0];
-            # int64_t fw = filterShape[1];
-            # int64_t ic = filterShape[2];
-            dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type)
-            return ProblemSize(
-                MatmulSize(
-                    M=dim_info.oh * dim_info.ow,
-                    N=dim_info.oc,
-                    K=dim_info.fh * dim_info.fw * dim_info.ic,
-                    B=dim_info.n,
-                ),
-                lhs_shaped_type,
-                rhs_shaped_type,
-                res_shaped_type,
-                DispatchKind.conv,
-            )
-
-        assert False, "Shape not found"
-
-    # int64_t n = outputShape[0];
-    # int64_t oh = outputShape[1];
-    # int64_t ow = outputShape[2];
-    # int64_t oc = outputShape[3];
-    # int64_t fh = filterShape[0];
-    # int64_t fw = filterShape[1];
-    # int64_t ic = filterShape[2];
-    def get_transform_function_conv(
-        self, problem_size: ProblemSize, functionName: str, configuration: Configuration
-    ) -> str:
-        dynamic_batch_input_ty = problem_size.lhs_type
-        dynamic_batch_input_ty.shape = dynamic_batch_input_ty.shape.copy()
-        dynamic_batch_input_ty.shape[0] = -1
-
-        dynamic_batch_output_ty = problem_size.res_type
-        dynamic_batch_output_ty.shape = dynamic_batch_output_ty.shape.copy()
-        dynamic_batch_output_ty.shape[0] = -1
-
-        input = f"tensor<{dynamic_batch_input_ty}>"
-        filter = f"tensor<{problem_size.rhs_type}>"
-        output = f"tensor<{dynamic_batch_output_ty}>"
-
-        tile_sizes = ", ".join(map(str, self.get_conv_tile_sizes(configuration)))
-
-        wg_x, wg_y, wg_z = configuration.workgroup_size
-        extra_config = get_pipeline_config(configuration)
-
-        return f"""
-    transform.named_sequence @{functionName}(%conv: !transform.any_op {{transform.readonly}})
-        -> (!transform.any_op, !transform.any_param) {{
-        %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv {{
-        ^bb0(%lhs: {input}, %rhs: {filter}, %out: {output}):
-            %13 = linalg.conv_2d_nhwc_hwcf {{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}}
-                ins(%lhs, %rhs : {input}, {filter})
-                outs(%out : {output}) -> {output}
-        }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
-        %config = transform.param.constant #iree_codegen.compilation_info<
-            lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
-            translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
-            workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
-            {{mma_schedule = #iree_gpu.mma_schedule<
-                intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
-                subgroup_m_count = 
{configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param - transform.yield %conv, %config : !transform.any_op, !transform.any_param - }} - """ - - def apply_params( - self, - problem_size: ProblemSize, - template: list[str], - configuration: Configuration, - ) -> MLIRTransformation: - conv_dims = ConvDimInfo.from_problem_size(problem_size) - modified = indent( - self.get_transform_function_conv( - problem_size, - f"match_conv_2d_nhwc_hwcf_Bx{conv_dims.oh}x{conv_dims.ow}x{conv_dims.oc}x{conv_dims.fh}x{conv_dims.fw}x{conv_dims.ic}", - configuration, - ), - "// ", - ) - modified += apply_configuration( - template, configuration, self.get_conv_tile_sizes(configuration) - ) - embeddable = indent( - self.get_transform_function_conv(problem_size, f"match_op", configuration), - " ", - ) - return MLIRTransformation(template, modified, embeddable) - - -class ContractionTuner(DispatchTuner): - def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): - self.lhs_dims = lhs_dims - self.rhs_dims = rhs_dims - self.tile_dims = tile_dims - - def supports(self, op_name: str) -> bool: - return "matmul_like" in op_name - - def is_broadcast_rhs_mmt_op(self, line: str) -> bool: - if "linalg.generic" not in line: - return False - if ( - r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' - not in line - ): - return False - if ( - r"indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>" - not in line - ): - return False - return True - - def is_broadcast_rhs_mmt(self, template: list[str]) -> bool: - return any(self.is_broadcast_rhs_mmt_op(line) for line in template) - - def get_shapes_broadcast_rhs_mmt(self, template: list[str]) -> ProblemSize: - for line in template: - if not self.is_broadcast_rhs_mmt_op(line): - continue - - # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) - bmmt_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(bmmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 3 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 2 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 3 - - B0, M0, K0 = lhs_shaped_type.shape - N1, K1 = rhs_shaped_type.shape - B2, M2, N2 = res_shaped_type.shape - assert B0 == B2 - assert M0 == M2 - assert N1 == N2 - assert K0 == K1 - return ProblemSize( - MatmulSize(M0, N1, K0, B0), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.broadcast_rhs_mmt, - ) - - assert False, "Shape not found" - - def get_shapes(self, template: list[str]) -> ProblemSize: - if self.is_broadcast_rhs_mmt(template): - return self.get_shapes_broadcast_rhs_mmt(template) - - for line in template: - if "linalg.generic" not in line: - continue - if "lowering_config =" not in line: - continue - if '"reduction"' not in line: - continue - - # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>) - cont_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(cont_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = 
dps.group("RHS")
-            lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
-            assert lhs_shaped_type.rank() == len(self.lhs_dims)
-
-            rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
-            assert rhs_shaped_type.rank() == len(self.rhs_dims)
-
-            res_tensor_type = dps.group("RES")
-            res_shaped_type = parse_tensor_type(res_tensor_type)
-            assert res_shaped_type.rank() >= 2
-
-            M = math.prod(
-                val if dim == "m" else 1
-                for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape)
-            )
-            N = math.prod(
-                val if dim == "n" else 1
-                for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape)
-            )
-            K0 = math.prod(
-                val if dim == "k" else 1
-                for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape)
-            )
-            K1 = math.prod(
-                val if dim == "k" else 1
-                for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape)
-            )
-            assert K0 == K1
-
-            return ProblemSize(
-                MatmulSize(M, N, K0),
-                lhs_type=lhs_shaped_type,
-                rhs_type=rhs_shaped_type,
-                res_type=res_shaped_type,
-                dispatch_kind=DispatchKind.contraction,
-            )
-
-        assert False, "Shape not found"
-
-    def get_transform_function_broadcast_rhs_mmt(
-        self,
-        problem_size: ProblemSize,
-        functionName: str,
-        configuration: Configuration,
-    ) -> str:
-        tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration)))
-
-        wg_x, wg_y, wg_z = configuration.workgroup_size
-        extra_config = get_pipeline_config(configuration)
-
-        lhs_dynamic_batch = problem_size.lhs_type
-        lhs_dynamic_batch.shape = lhs_dynamic_batch.shape.copy()
-        lhs_dynamic_batch.shape[0] = -1
-
-        return f"""
-transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
-%mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op
-%lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value
-%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
-transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value
-transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
-%config = transform.param.constant #iree_codegen.compilation_info<
-    lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
-    translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
-    workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
-        {{mma_schedule = #iree_gpu.mma_schedule<
-            intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
-            subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
-        {extra_config}}}>
-    > -> !transform.any_param
-transform.yield %generic, %config : !transform.any_op, !transform.any_param
-}}
-"""
-
-    def apply_params_broadcast_rhs_mmt(
-        self,
-        problem_size: ProblemSize,
-        template: list[str],
-        configuration: Configuration,
-    ) -> MLIRTransformation:
-        M, N, K = problem_size.MNK
-        modified = indent(
-            self.get_transform_function_broadcast_rhs_mmt(
-                problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", configuration
-            ),
-            "// ",
-        )
-        modified += apply_configuration(
-            template, configuration, get_batch_mmt_tile_sizes(configuration)
-        )
-
-        embeddable = indent(
-            self.get_transform_function_broadcast_rhs_mmt(
-                problem_size, f"match_op", configuration
-            ),
-            "  ",
-        )
-        return MLIRTransformation(template, modified, embeddable)
-
-    def apply_params(
-        self,
-        problem_size: ProblemSize,
-        template: list[str],
-        configuration: Configuration,
-    ) -> MLIRTransformation:
-        if self.is_broadcast_rhs_mmt(template):
-            return self.apply_params_broadcast_rhs_mmt(
-                problem_size, template, configuration
-            )
-
-        # TODO: Generate transform function.
-        return MLIRTransformation(
-            template,
-            apply_configuration(
-                template,
-                configuration,
-                get_contract_tile_sizes(configuration, self.tile_dims),
-            ),
-            "",
-        )
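The `lhs_dims`/`rhs_dims`/`tile_dims` strings assign a role letter to each tensor dimension; `math.prod` then folds every dimension sharing a role into a single matmul extent, and `get_contract_tile_sizes` expands the three tile sizes back out along `tile_dims`. A small worked example (values arbitrary):

```python
cfg = Configuration(
    subgroup_size=64,
    workgroup_size=[256, 1, 1],
    intrinsic=MfmaIntrinsic.mfma_f16_16x16x16_f32(),
    tile_sizes=[128, 320, 32],
    subgroup_m_count=2,
    subgroup_n_count=2,
    waves_per_eu=2,
)
# Letters outside {m, n, k} (e.g. a batch dim "b") keep a tile size of 1.
assert get_contract_tile_sizes(cfg, "bmnk") == [1, 128, 320, 32]
```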
-
-
-class BatchMmtTuner(DispatchTuner):
-    def supports(self, op_name: str) -> bool:
-        return "batch_matmul_transpose_b" in op_name
-
-    def get_shapes(self, template: list[str]) -> ProblemSize:
-        for line in template:
-            if "linalg.generic" not in line:
-                continue
-            if (
-                r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]'
-                not in line
-            ):
-                continue
-            # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>)
-            bmmt_re = (
-                rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
-            )
-            dps = re.search(bmmt_re, line)
-            if dps is None:
-                continue
-
-            lhs_tensor_type = dps.group("LHS")
-            rhs_tensor_type = dps.group("RHS")
-            lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
-            assert lhs_shaped_type.rank() == 3
-
-            rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
-            assert rhs_shaped_type.rank() == 3
-
-            res_tensor_type = dps.group("RES")
-            res_shaped_type = parse_tensor_type(res_tensor_type)
-            assert res_shaped_type.rank() == 3
-
-            B0, M0, K0 = lhs_shaped_type.shape
-            B1, N1, K1 = rhs_shaped_type.shape
-            B2, M2, N2 = res_shaped_type.shape
-            assert B0 == B1
-            assert B0 == B2
-            assert M0 == M2
-            assert N1 == N2
-            assert K0 == K1
-            return ProblemSize(
-                MatmulSize(M0, N1, K0, B0),
-                lhs_shaped_type,
-                rhs_shaped_type,
-                res_shaped_type,
-                DispatchKind.batch_mmt,
-            )
-
-        assert False, "Shape not found"
-
-    def get_transform_function_batch_mmt(
-        self,
-        problem_size: ProblemSize,
-        functionName: str,
-        configuration: Configuration,
-    ) -> str:
-        tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration)))
-
-        wg_x, wg_y, wg_z = configuration.workgroup_size
-        extra_config = get_pipeline_config(configuration)
-
-        return f"""
-transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
-%mmt = transform.include @match_batch_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op
-%lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value
-%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
-transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
-transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
-%config = transform.param.constant #iree_codegen.compilation_info<
-    lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
-    translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
-    workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
-        {{mma_schedule = #iree_gpu.mma_schedule<
-            intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
-            subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
-        {extra_config}}}>
-    > -> !transform.any_param
-transform.yield %generic, %config : !transform.any_op, !transform.any_param
-}}
-"""
-
-    def apply_params(
-        self,
-        problem_size: ProblemSize,
-        template: list[str],
-        configuration: Configuration,
-    ) -> MLIRTransformation:
-        M, N, K = problem_size.MNK
-        B = problem_size.matmul_size.B
-        modified = indent(
-            self.get_transform_function_batch_mmt(
-                problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", configuration
-            ),
-            "// ",
-        )
-        modified += apply_configuration(
-            template, configuration, get_batch_mmt_tile_sizes(configuration)
-        )
-
-        embeddable = indent(
-            self.get_transform_function_batch_mmt(
-                problem_size, f"match_op", configuration
-            ),
-            "  ",
-        )
-        return MLIRTransformation(template, modified, embeddable)
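As with the broadcast-RHS variant above, `BatchMmtTuner` reuses the matmul tile sizes but prepends a unit batch tile via `get_batch_mmt_tile_sizes`, so each workgroup handles exactly one batch slice. A quick check (configuration values arbitrary):

```python
cfg = Configuration(
    subgroup_size=64,
    workgroup_size=[128, 2, 1],
    intrinsic=MfmaIntrinsic.mfma_i8_32x32x16_i32(),
    tile_sizes=[128, 128, 64],
    subgroup_m_count=2,
    subgroup_n_count=2,
    waves_per_eu=2,
)
# [batch, m, n, k]: the leading 1 pins the batch tile.
assert get_batch_mmt_tile_sizes(cfg) == [1, 128, 128, 64]
```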
-
-
-class BatchMatmulTuner(DispatchTuner):
-    def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str):
-        self.lhs_dims = lhs_dims
-        self.rhs_dims = rhs_dims
-        self.tile_dims = tile_dims
-
-    def supports(self, op_name: str) -> bool:
-        return "batch_matmul" in op_name
-
-    def get_shapes(self, template: list[str]) -> ProblemSize:
-        for line in template:
-            if "linalg.batch_matmul" not in line:
-                continue
-            # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>)
-            # outs(%12 : tensor<64x72x1280xf32>)
-            cont_re = (
-                rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
-            )
-            dps = re.search(cont_re, line)
-            if dps is None:
-                continue
-
-            lhs_tensor_type = dps.group("LHS")
-            rhs_tensor_type = dps.group("RHS")
-            lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
-            assert lhs_shaped_type.rank() == len(self.lhs_dims)
-
-            rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
-            assert rhs_shaped_type.rank() == len(self.rhs_dims)
-
-            res_tensor_type = dps.group("RES")
-            res_shaped_type = parse_tensor_type(res_tensor_type)
-            assert res_shaped_type.rank() == lhs_shaped_type.rank()
-
-            LHS = lhs_shaped_type.shape
-            RHS = rhs_shaped_type.shape
-            RES = res_shaped_type.shape
-
-            B = math.prod(
-                val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, LHS)
-            )
-            B0 = math.prod(
-                val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RHS)
-            )
-            B1 = math.prod(
-                val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RES)
-            )
-            M = math.prod(
-                val if dim == "m" else 1 for dim, val in zip(self.lhs_dims, LHS)
-            )
-            N = math.prod(
-                val if dim == "n" else 1 for dim, val in zip(self.rhs_dims, RHS)
-            )
-            K0 = math.prod(
-                val if dim == "k" else 1 for dim, val in zip(self.lhs_dims, LHS)
-            )
-            K1 = math.prod(
-                val if dim == "k" else 1 for dim, val in zip(self.rhs_dims, RHS)
-            )
-            assert B == B0 and B == B1
-            assert K0 == K1
-
-            return ProblemSize(
-                MatmulSize(M, N, K0, B),
-                lhs_type=lhs_shaped_type,
-                rhs_type=rhs_shaped_type,
-                res_type=res_shaped_type,
-                dispatch_kind=DispatchKind.batch_matmul,
-            )
-
-        assert False, "Shape not found"
-
-    def get_transform_function_batch_matmul(
-        self,
-        problem_size: ProblemSize,
-        tile_dims: str,
-        functionName: str,
-        configuration: Configuration,
-    ) -> str:
-        input0 = f"tensor<{problem_size.lhs_type}>"
-        input1 = f"tensor<{problem_size.rhs_type}>"
-        output = f"tensor<{problem_size.res_type}>"
-
-        tile_sizes = ", ".join(
-            map(str, get_contract_tile_sizes(configuration, tile_dims))
-        )
-
-        wg_x, wg_y, wg_z = configuration.workgroup_size
-        extra_config = get_pipeline_config(configuration)
-
-        return f"""
-    transform.named_sequence @{functionName}(%batch_matmul: !transform.any_op {{transform.readonly}})
-        -> (!transform.any_op, !transform.any_param) {{
-        %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %batch_matmul {{
-        ^bb0(%lhs: {input0}, %rhs: {input1}, %out: {output}):
-            %13 = linalg.batch_matmul
-                ins(%lhs, %rhs : {input0}, {input1})
-                outs(%out : {output}) -> {output}
-        }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
-        %config = transform.param.constant #iree_codegen.compilation_info<
-            lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
-            translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
-            workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
-            {{mma_schedule = #iree_gpu.mma_schedule<
-                intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
-                subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
-            {extra_config}}}>
-        > -> !transform.any_param
-        transform.yield %batch_matmul, %config : !transform.any_op, 
!transform.any_param - }} - """ - - def apply_params( - self, - problem_size: ProblemSize, - template: list[str], - configuration: Configuration, - ) -> MLIRTransformation: - M, N, K = problem_size.MNK - modified = indent( - self.get_transform_function_batch_matmul( - problem_size, - self.tile_dims, - f"match_batch_matmul_{problem_size.matmul_size.B}x{M}x{N}x{K}", - configuration, - ), - "// ", - ) - modified += apply_configuration( - template, - configuration, - get_contract_tile_sizes(configuration, self.tile_dims), - ) - - embeddable = indent( - self.get_transform_function_batch_matmul( - problem_size, self.tile_dims, f"match_op", configuration - ), - " ", - ) - return MLIRTransformation(template, modified, embeddable) - - -def walk_callback_get_fn( - op: ir.Operation, - walk_result: OpWalkResult, - dispatch_tuner_registry: DispatchTunerRegistry, -) -> ir.WalkResult: - if op.name == "func.func": - dispatch_tuner_registry.validate_translation([a for a in op.opview.attributes]) - if op.name == "util.func": - func_name = str(op.opview.sym_name) - walk_result.was_interrupted = True - walk_result.dispatch_tuner = dispatch_tuner_registry.find_handler(func_name) - return ir.WalkResult.INTERRUPT - return ir.WalkResult.ADVANCE - - -def walk_mlir_op( - mlir_module: ir.Module, - dispatch_tuner_registry: DispatchTunerRegistry, -) -> OpWalkResult: - walk_result = OpWalkResult() - for op in mlir_module.body.operations: - op.walk( - lambda op: walk_callback_get_fn(op, walk_result, dispatch_tuner_registry), - ir.WalkOrder.POST_ORDER, - ) - if walk_result.was_interrupted: - break - return walk_result - - -def tune( - input: str, # Path to the mlir file to be tuned - output: str = "", # Path to the output directory, auto creates one if not given - limit: int = 4096, # Max candidates to be generated - num_subgroups: int = 4, # GPU spec, used to determine candidate generation constraints - lhs_dims: str = "mk", # Dimensions for the left-hand side operand in matrix operations - rhs_dims: str = "nk", # Dimensions for the right-hand side operand in matrix operations - tile_dims: str = "mnk", # Dimensions for the tile size -): - input_file = str(input) - - if not output: - output = get_default_output_dir() - - # Create the directory if it does not exist - makedirs(str(output), exist_ok=True) - - tune_logger.debug(f"Output directory {output}") - tune_logger.debug(f"Processing {input_file}") - mlir_template = read_input_mlir(input_file) - mlir_text = "".join(mlir_template) - - mlir_module = parse_mlir(mlir_text) - # Save the input file as the first candidate. 
- with open(path.join(output, f"0.mlir"), "w") as f: - f.write(mlir_text) - - dispatch_tuner_registry = DispatchTunerRegistry() - dispatch_tuner_registry.register( - [ - MmtTuner(), - ConvTuner(), - ContractionTuner(lhs_dims, rhs_dims, tile_dims), - BatchMmtTuner(), - BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims), - ] - ) - - walk_result = walk_mlir_op(mlir_module, dispatch_tuner_registry) - - dispatch_tuner = walk_result.dispatch_tuner - problem_size = dispatch_tuner.get_shapes(mlir_template) - tune_logger.debug(str(problem_size)) - configs = [] - for i, config in enumerate(generate_solutions(problem_size, num_subgroups)): - if i >= limit: - break - tune_logger.info(f"Solution #{i+1}: {config}") - configs.append(config) - tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config) - - with open(path.join(output, f"{i+1}.mlir"), "w") as f: - f.write(tf_mlir.modified) - with open(path.join(output, f"{i+1}_config.mlir"), "w") as f: - f.write(tf_mlir.embeddable) - - with open(path.join(output, "configs.pkl"), "wb") as file: - pickle.dump(configs, file) - - tune_logger.info(f"Generated {len(configs)} candidates") - tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl") - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("input", help="Input mlir file", type=str) - parser.add_argument( - "-o", "--output", help="Output dir", type=str, default=get_default_output_dir() - ) - parser.add_argument( - "-l", - "--limit", - help="Max number of candidates generated", - type=int, - default=4096, - ) - parser.add_argument( - "--num-subgroups", - help="Number of subgroups per workgroup to use. (-1 == unconstrained)", - type=int, - default=-1, - ) - parser.add_argument( - "--lhs-dims", help="Map of LHS matmul dims", type=str, default="mk" - ) - parser.add_argument( - "--rhs-dims", help="Map of RHS matmul dims", type=str, default="nk" - ) - parser.add_argument( - "--tile-dims", help="Map of tile size matmul dims", type=str, default="mnk" - ) - parser.add_argument( - "--verbose", "-v", action="store_true", help="Enable verbose output to stdout" - ) - - args = parser.parse_args() - tune_logger.setLevel(logging.DEBUG if args.verbose else logging.INFO) - - # Create printing formatter for logging info - formatter = logging.Formatter("%(message)s") - - # Create a handler to print to console - console_handler = logging.StreamHandler() - console_handler.setFormatter(formatter) - tune_logger.addHandler(console_handler) - - # # Optionally, add a file handler to log to a file - # file_handler = logging.FileHandler("tune.log") - # file_handler.setFormatter(formatter) - # tune_logger.addHandler(file_handler) - - tune( - args.input, - args.output, - args.limit, - args.num_subgroups, - args.lhs_dims, - args.rhs_dims, - args.tile_dims, - ) - - -if __name__ == "__main__": - args = main() diff --git a/sharktank/sharktank/tools/tuner/requirements-dev.txt b/sharktank/sharktank/tools/tuner/requirements-dev.txt deleted file mode 100644 index 51d5b9ba0..000000000 --- a/sharktank/sharktank/tools/tuner/requirements-dev.txt +++ /dev/null @@ -1,2 +0,0 @@ -pre-commit==3.8.0 -virtualenv==20.13.0 diff --git a/sharktank/sharktank/tools/tuner/requirements-tuner.txt b/sharktank/sharktank/tools/tuner/requirements-tuner.txt deleted file mode 100644 index f3484c921..000000000 --- a/sharktank/sharktank/tools/tuner/requirements-tuner.txt +++ /dev/null @@ -1,4 +0,0 @@ -pytest==8.2.2 -tqdm==4.66.4 -z3_solver==4.13.0.0 -types-tqdm==4.66.0.20240417 diff --git 
a/sharktank/tests/tuner/candidate_gen_test.py b/sharktank/tests/tuner/candidate_gen_test.py deleted file mode 100644 index 4fc21aa63..000000000 --- a/sharktank/tests/tuner/candidate_gen_test.py +++ /dev/null @@ -1,814 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -""" -Usage: python -m pytest candidate_gen_test.py -""" - -import pytest -from sharktank.tools.tuner import candidate_gen - - -def test_get_shaped_type_element_bitwidth(): - assert ( - candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8).bitwidth - == 8 - ) - assert ( - candidate_gen.ShapedType([2048], candidate_gen.ElementType.i32).bitwidth == 32 - ) - assert ( - candidate_gen.ShapedType( - [2048, 512, 384], candidate_gen.ElementType.f8 - ).bitwidth - == 8 - ) - assert ( - candidate_gen.ShapedType([1, 1], candidate_gen.ElementType.f16).bitwidth == 16 - ) - - -def test_get_shaped_type_to_str(): - assert ( - str(candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8)) - == "1024x2048xi8" - ) - assert ( - str(candidate_gen.ShapedType([1024], candidate_gen.ElementType.f32)) - == "1024xf32" - ) - assert ( - str(candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f16)) - == "1x2x3xf16" - ) - assert ( - str(candidate_gen.ShapedType([-1, 2, 3], candidate_gen.ElementType.f16)) - == "?x2x3xf16" - ) - - -def test_parse_tensor_type(): - assert candidate_gen.parse_tensor_type( - "tensor<1x2x3xf32>" - ) == candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f32) - assert candidate_gen.parse_tensor_type( - "tensor<123xi8>" - ) == candidate_gen.ShapedType([123], candidate_gen.ElementType.i8) - - -def test_get_mmt_tile_sizes(): - config = candidate_gen.Configuration( - subgroup_size=0, - workgroup_size=[], - intrinsic="", - tile_sizes=[128, 320, 32], - subgroup_m_count=0, - subgroup_n_count=0, - waves_per_eu=0, - ) - assert candidate_gen.get_mmt_tile_sizes(config) == [128, 320, 32] - - -def test_get_conv_tile_sizes(): - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[256, 1, 1], - intrinsic="#iree_gpu.mma_layout", - tile_sizes=[464, 320, 16], - subgroup_m_count=1, - subgroup_n_count=4, - waves_per_eu=1, - ) - assert candidate_gen.ConvTuner().get_conv_tile_sizes(config) == [ - 1, - 1, - 464, - 320, - 1, - 1, - 16, - ] - - -def test_get_contract_tile_sizes(): - config = candidate_gen.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], - intrinsic="", - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, - waves_per_eu=2, - ) - assert candidate_gen.get_contract_tile_sizes(config, ["m", "n", "k"]) == [4, 8, 16] - assert candidate_gen.get_contract_tile_sizes(config, ["n", "m", "k"]) == [8, 4, 16] - assert candidate_gen.get_contract_tile_sizes(config, ["k", "n", "m"]) == [16, 8, 4] - assert candidate_gen.get_contract_tile_sizes(config, ["k", "k", "k"]) == [ - 16, - 16, - 16, - ] - - -def test_get_pipeline_config(): - config1 = candidate_gen.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], - intrinsic="", - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, - waves_per_eu=2, - ) - config2 = candidate_gen.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], - intrinsic="", - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, - waves_per_eu=4, - ) - assert candidate_gen.get_pipeline_config(config1) 
== ", prefetch_shared_memory" - assert ( - candidate_gen.get_pipeline_config(config2) - == ', prefetch_shared_memory, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' - ) - - -def test_get_shapes_mmt(): - template = [ - r"%18 = tensor.empty() : tensor<2048x1280xf32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"^bb0(%in: f16, %in_0: f16, %out: f32):", - ] - assert candidate_gen.MmtTuner().get_shapes(template) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.mmt, - ) - - -def test_get_shapes_conv(): - template = [ - r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", - r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", - r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.ConvTuner().get_shapes(template) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(32, 256, 11520), - candidate_gen.ShapedType([1, 3, 34, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([3, 3, 1280, 256], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1, 1, 32, 256], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.conv, - ) - - -def test_get_shapes_contract(): - template = [ - r"%18 = tensor.empty() : tensor<2048x1280xf32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"^bb0(%in: f16, %in_0: f16, %out: f32):", - ] - assert candidate_gen.ContractionTuner("mk", "nk", "mnk").get_shapes( - template - ) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.contraction, - ) - - -def test_get_shapes_batch_matmul(): - template = [ - "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", - 
"%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", - "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.BatchMatmulTuner("bmk", "bkn", "mnk").get_shapes( - template - ) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(32, 32, 1024, 1), - candidate_gen.ShapedType([1, 32, 1024], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([1, 1024, 32], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([1, 32, 32], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_matmul, - ) - - -def test_get_shapes_batch_mmt(): - template = [ - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.BatchMmtTuner().get_shapes( - template - ) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.batch_mmt, - ) - - -def test_mfma_intrinsic_to_str(): - assert ( - str(candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32()) - == "MFMA_F16_16x16x16_F32" - ) - assert ( - str(candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32()) - == "MFMA_I8_32x32x16_I32" - ) - - -def test_get_compatible_mfma_intrinsics(): - assert candidate_gen.get_compatible_mfma_intrinsics( - candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.mmt, - ) - ) == [ - candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), - candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), - ] - - assert candidate_gen.get_compatible_mfma_intrinsics( - candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.mmt, - ) - ) == [ - candidate_gen.MfmaIntrinsic.mfma_i8_16x16x32_i32(), - candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(), - ] - - assert candidate_gen.get_compatible_mfma_intrinsics( - candidate_gen.ProblemSize( - candidate_gen.MatmulSize(968, 320, 640, 64), - candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([64, 968, 
320], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_matmul, - ) - ) == [ - candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), - candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), - ] - - -def test_generate_solutions(): - matmul_size = candidate_gen.MatmulSize(2048, 3840, 1280) - lhs_type = candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([3840, 1280], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([2048, 3840], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - configs = candidate_gen.generate_solutions(problem_size, 4) - assert configs is not None - - -def test_calculate_shared_memory_usage_in_bytes(): - matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - assert ( - candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) - == 147456 - ) - - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i8) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - assert ( - candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) - == 81920 - ) - - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - assert ( - candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32) - == 12288 - ) - - -def test_generate_constraints_valid_input(): - matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - # Define input parameters as z3 Ints - m, n, k = ( - candidate_gen.z3.Int("m"), - candidate_gen.z3.Int("n"), - candidate_gen.z3.Int("k"), - ) - subgroup_size = candidate_gen.z3.Int("subgroup_size") - intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn") - intrinsic_k = candidate_gen.z3.Int("intrinsic_k") - wg_x, wg_y, wg_z = ( - candidate_gen.z3.Int("wg_x"), - candidate_gen.z3.Int("wg_y"), - candidate_gen.z3.Int("wg_z"), - ) - sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt") - sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") - waves_per_eu = candidate_gen.z3.Int("waves_per_eu") - - constraints = candidate_gen.generate_constraints( - problem_size, - [m, n, k], - 4, - subgroup_size, - [intrinsic_mn, intrinsic_k], - [wg_x, wg_y, wg_z], - sg_m_cnt, - sg_n_cnt, - waves_per_eu, - ) - - solver = candidate_gen.z3.Solver() - solver.add(constraints) - - # Check if the constraints are satisfiable - assert solver.check() == candidate_gen.z3.sat - - -def test_generate_constraints_invalid_input(): - # Define input parameters that should lead to unsatisfiable constraints - 
-def test_generate_constraints_invalid_input():
-    # Define input parameters that should lead to unsatisfiable constraints
-    matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024)
-    lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
-    rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
-    res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32)
-    problem_size = candidate_gen.ProblemSize(
-        matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
-    )
-    m, n, k = (
-        candidate_gen.z3.Int("m"),
-        candidate_gen.z3.Int("n"),
-        candidate_gen.z3.Int("k"),
-    )
-    subgroup_size = candidate_gen.z3.Int("subgroup_size")
-    intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn")
-    intrinsic_k = candidate_gen.z3.Int("intrinsic_k")
-    wg_x, wg_y, wg_z = (
-        candidate_gen.z3.Int("wg_x"),
-        candidate_gen.z3.Int("wg_y"),
-        candidate_gen.z3.Int("wg_z"),
-    )
-    sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt")
-    sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt")
-    waves_per_eu = candidate_gen.z3.Int("waves_per_eu")
-
-    constraints = candidate_gen.generate_constraints(
-        problem_size,
-        [m, n, k],
-        4,
-        subgroup_size,
-        [intrinsic_mn, intrinsic_k],
-        [wg_x, wg_y, wg_z],
-        sg_m_cnt,
-        sg_n_cnt,
-        waves_per_eu,
-    )
-    constraints.append(m > 1000)  # Add an additional, unsatisfiable constraint
-
-    solver = candidate_gen.z3.Solver()
-    solver.add(constraints)
-
-    # Check that the constraints are unsatisfiable
-    assert solver.check() == candidate_gen.z3.unsat
-
-
-def test_apply_params_mmt():
-    mlir_template = [
-        ", subgroup_m_count = 16, subgroup_n_count = 16>",
-        "",
-        '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}',
-    ]
-
-    M, N, K = 2048, 1280, 1280
-
-    config = candidate_gen.Configuration(
-        subgroup_size=16,
-        workgroup_size=[16, 16, 1],
-        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(),
-        tile_sizes=[8, 8, 8],
-        subgroup_m_count=16,
-        subgroup_n_count=16,
-        waves_per_eu=8,
-    )
-
-    problem_size = candidate_gen.ProblemSize(
-        candidate_gen.MatmulSize(M, N, K),
-        candidate_gen.ShapedType([M, K], candidate_gen.ElementType.f16),
-        candidate_gen.ShapedType([N, K], candidate_gen.ElementType.f16),
-        candidate_gen.ShapedType([M, N], candidate_gen.ElementType.f32),
-        candidate_gen.DispatchKind.mmt,
-    )
-    tf_mlir = candidate_gen.MmtTuner().apply_params(problem_size, mlir_template, config)
-
-    modified = tf_mlir.modified
-    embeddable = tf_mlir.embeddable
-
-    assert modified
-    assert embeddable
-    assert (
-        "intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 16, subgroup_n_count = 16"
-        in modified
-    )
-    assert (
-        "LLVMGPUVectorDistribute workgroup_size = [16, 16, 1] subgroup_size = 16"
-        in modified
-    )
-    assert "tile_sizes = [[8, 8, 8]]" in modified
-    assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "8"}' in modified
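As a reading aid for the `apply_params` assertions, here is how plain configuration values line up with the substrings the MMT test expects in the rewritten MLIR. The formatting helper below is ours, not `candidate_gen`'s API; the real substitution is performed on actual MLIR text.

```python
# Hypothetical illustration: build the substrings the MMT test checks for
# from plain config values. The function name and signature are ours.
def expected_snippets(intrinsic, wg, sg_size, sg_m, sg_n, tiles, waves):
    return [
        f"intrinsic = #iree_gpu.mma_layout<{intrinsic}>, "
        f"subgroup_m_count = {sg_m}, subgroup_n_count = {sg_n}",
        f"LLVMGPUVectorDistribute workgroup_size = {wg} subgroup_size = {sg_size}",
        f"tile_sizes = [[{', '.join(map(str, tiles))}]]",
        f'"amdgpu-waves-per-eu" = "{waves}"',
    ]

snippets = expected_snippets(
    "MFMA_F16_16x16x16_F32", [16, 16, 1], 16, 16, 16, [8, 8, 8], 8
)
assert snippets[2] == "tile_sizes = [[8, 8, 8]]"  # as asserted above
```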
-def test_apply_params_conv():
-    mlir_template = [
-        ", subgroup_m_count = 16, subgroup_n_count = 16>",
-        "",
-        '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}',
-    ]
-
-    n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640
-
-    config = candidate_gen.Configuration(
-        subgroup_size=64,
-        workgroup_size=[256, 1, 1],
-        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(),
-        tile_sizes=[464, 320, 16],
-        subgroup_m_count=1,
-        subgroup_n_count=4,
-        waves_per_eu=2,
-    )
-
-    problem_size = candidate_gen.ProblemSize(
-        candidate_gen.MatmulSize(oh * ow, oc, fh * fw * ic),
-        candidate_gen.ShapedType(
-            [n, oh + 2, ow + 2, oc], candidate_gen.ElementType.f16
-        ),
-        candidate_gen.ShapedType([fh, fw, ic, oc], candidate_gen.ElementType.f16),
-        candidate_gen.ShapedType([n, oh, ow, oc], candidate_gen.ElementType.f32),
-        candidate_gen.DispatchKind.conv,
-    )
-    tf_mlir = candidate_gen.ConvTuner().apply_params(
-        problem_size, mlir_template, config
-    )
-
-    modified = tf_mlir.modified
-    embeddable = tf_mlir.embeddable
-
-    assert modified
-    assert embeddable
-    assert (
-        "intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 4"
-        in modified
-    )
-    assert (
-        "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64"
-        in modified
-    )
-    assert "tile_sizes = [[1, 1, 464, 320, 1, 1, 16]]" in modified
-    assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified
-
-
-def test_apply_params_contract():
-    mlir_template = [
-        ", subgroup_m_count = 2, subgroup_n_count = 2>}>",
-        "",
-        '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}',
-    ]
-
-    tile_dims = "*mnk"
-    problem_size = candidate_gen.ProblemSize(
-        candidate_gen.MatmulSize(2048, 3840, 1280),
-        candidate_gen.ShapedType([2, 1024, 1280], candidate_gen.ElementType.f16),
-        candidate_gen.ShapedType([3, 20, 64, 1280], candidate_gen.ElementType.f16),
-        candidate_gen.ShapedType([3, 2, 20, 1024, 64], candidate_gen.ElementType.f32),
-        candidate_gen.DispatchKind.contraction,
-    )
-
-    config = candidate_gen.Configuration(
-        subgroup_size=64,
-        workgroup_size=[256, 1, 1],
-        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(),
-        tile_sizes=[480, 384, 32],
-        subgroup_m_count=1,
-        subgroup_n_count=4,
-        waves_per_eu=2,
-    )
-
-    tf_mlir = candidate_gen.ContractionTuner("mk", "nk", tile_dims).apply_params(
-        problem_size, mlir_template, config
-    )
-
-    new_mlir = tf_mlir.modified
-
-    assert new_mlir
-    assert (
-        "intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 1, subgroup_n_count = 4"
-        in new_mlir
-    )
-    assert (
-        "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64"
-        in new_mlir
-    )
-    assert "tile_sizes = [[1, 480, 384, 32]]" in new_mlir
-    assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in new_mlir
-
-
-def test_apply_params_batch_matmul():
-    mlir_template = [
-        ", subgroup_m_count = 4, subgroup_n_count = 1>}>",
-        "",
-        '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}',
-    ]
-
-    tile_dims = "bmnk"
-    problem_size = candidate_gen.ProblemSize(
-        candidate_gen.MatmulSize(968, 320, 640, 64),
-        candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f16),
-        candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f16),
-        candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32),
-        candidate_gen.DispatchKind.batch_matmul,
-    )
-
-    config = candidate_gen.Configuration(
-        subgroup_size=64,
-        workgroup_size=[128, 2, 1],
-        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(),
-        tile_sizes=[416, 320, 128],
-        subgroup_m_count=2,
-        subgroup_n_count=2,
-        waves_per_eu=2,
-    )
-
-    tf_mlir = candidate_gen.BatchMatmulTuner("mk", "nk", tile_dims).apply_params(
-        problem_size, mlir_template, config
-    )
-
-    modified = tf_mlir.modified
-    embeddable = tf_mlir.embeddable
-
-    assert modified
-    assert embeddable
-    assert (
-        "intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 2, subgroup_n_count = 2"
-        in modified
-    )
-    assert (
-        "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64"
-        in modified
-    )
-    assert "tile_sizes = [[1, 416, 320, 128]]" in modified
-    assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified
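The contraction and batch-matmul expectations above show how a 3-element `tile_sizes` is padded out to the dispatch's full iteration space: batch or broadcast dimensions take a tile of 1, while the `m`/`n`/`k` positions take the configured tiles. A hypothetical helper capturing that relationship, inferred from the assertions rather than taken from `candidate_gen`:

```python
# Hypothetical: expand m/n/k tiles according to a tile_dims string, giving
# every non-mnk dimension a tile size of 1 (inferred from the tests above).
def full_tile_sizes(tile_dims: str, m: int, n: int, k: int) -> list[int]:
    mapping = {"m": m, "n": n, "k": k}
    return [mapping.get(d, 1) for d in tile_dims]

assert full_tile_sizes("*mnk", 480, 384, 32) == [1, 480, 384, 32]    # contraction
assert full_tile_sizes("bmnk", 416, 320, 128) == [1, 416, 320, 128]  # batch matmul
```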
-def test_apply_params_batch_mmt_float():
-    mlir_template = [
-        ", subgroup_m_count = 4, subgroup_n_count = 1>}>",
-        "",
-        '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}',
-    ]
-
-    problem_size = candidate_gen.ProblemSize(
-        candidate_gen.MatmulSize(4096, 640, 640, 2),
-        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f16),
-        candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.f16),
-        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f32),
-        candidate_gen.DispatchKind.batch_mmt,
-    )
-
-    config = candidate_gen.Configuration(
-        subgroup_size=64,
-        workgroup_size=[128, 2, 1],
-        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(),
-        tile_sizes=[128, 64, 128],
-        subgroup_m_count=2,
-        subgroup_n_count=2,
-        waves_per_eu=2,
-    )
-
-    tf_mlir = candidate_gen.BatchMmtTuner().apply_params(
-        problem_size, mlir_template, config
-    )
-
-    modified = tf_mlir.modified
-    embeddable = tf_mlir.embeddable
-
-    assert embeddable
-    assert modified
-    assert (
-        "intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 2, subgroup_n_count = 2"
-        in modified
-    )
-    assert (
-        "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64"
-        in modified
-    )
-    assert "tile_sizes = [[1, 128, 64, 128]]" in modified
-    assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified
-
-
-def test_apply_params_batch_mmt_int():
-    mlir_template = [
-        ", subgroup_m_count = 4, subgroup_n_count = 1>}>",
-        "",
-        '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}',
-    ]
-
-    problem_size = candidate_gen.ProblemSize(
-        candidate_gen.MatmulSize(4096, 640, 640, 2),
-        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8),
-        candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8),
-        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32),
-        candidate_gen.DispatchKind.batch_mmt,
-    )
-
-    config = candidate_gen.Configuration(
-        subgroup_size=64,
-        workgroup_size=[128, 2, 1],
-        intrinsic=candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(),
-        tile_sizes=[128, 64, 128],
-        subgroup_m_count=2,
-        subgroup_n_count=2,
-        waves_per_eu=4,
-    )
-
-    tf_mlir = candidate_gen.BatchMmtTuner().apply_params(
-        problem_size, mlir_template, config
-    )
-
-    modified = tf_mlir.modified
-    embeddable = tf_mlir.embeddable
-
-    assert modified
-    assert "// transform.named_sequence @match_batch_mmt_2x4096x640x640(" in modified
-    assert (
-        "intrinsic = #iree_gpu.mma_layout<MFMA_I8_32x32x16_I32>, subgroup_m_count = 2, subgroup_n_count = 2"
-        in modified
-    )
-    assert (
-        "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64"
-        in modified
-    )
-    assert "tile_sizes = [[1, 128, 64, 128]]" in modified
-    assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified
-
-    assert embeddable
-    assert "transform.named_sequence @match_op(" in embeddable
-    assert (
-        "transform.include @match_batch_mmt_i8_i8_i32 failures(propagate)" in embeddable
-    )
-    assert (
-        "transform.iree.match.cast_compatible_type %lhs = tensor<2x4096x640xi8> : !transform.any_value"
-        in embeddable
-    )
-    assert (
-        "transform.iree.match.cast_compatible_type %rhs = tensor<2x640x640xi8> : !transform.any_value"
-        in embeddable
-    )
-    assert (
-        "%config = transform.param.constant #iree_codegen.compilation_info<"
-        in embeddable
-    )
-    assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable
-    assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable
-    assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable
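The embeddable-spec assertions recur in the batch-mmt and broadcast-rhs-mmt tests and follow one shape: a generic `@match_op` entry point, a `transform.include` of a typed matcher, and cast-compatible type checks on both operands. A hypothetical helper that captures the pattern (the substrings are exactly the ones the surrounding tests assert):

```python
# Hypothetical helper collecting the recurring embeddable-spec checks.
def check_embeddable(embeddable: str, matcher: str, lhs_ty: str, rhs_ty: str):
    assert "transform.named_sequence @match_op(" in embeddable
    assert f"transform.include @{matcher} failures(propagate)" in embeddable
    for name, ty in (("%lhs", lhs_ty), ("%rhs", rhs_ty)):
        assert (
            f"transform.iree.match.cast_compatible_type {name} = {ty} : !transform.any_value"
            in embeddable
        )

# Usage for the int batch-mmt case above:
# check_embeddable(embeddable, "match_batch_mmt_i8_i8_i32",
#                  "tensor<2x4096x640xi8>", "tensor<2x640x640xi8>")
```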
-def test_apply_params_broadcast_rhs_mmt():
-    mlir_template = [
-        ", subgroup_m_count = 4, subgroup_n_count = 1>}>",
-        "",
-        '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}',
-    ]
-
-    problem_size = candidate_gen.ProblemSize(
-        candidate_gen.MatmulSize(4096, 640, 640, 2),
-        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8),
-        candidate_gen.ShapedType([640, 640], candidate_gen.ElementType.i8),
-        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32),
-        candidate_gen.DispatchKind.broadcast_rhs_mmt,
-    )
-
-    config = candidate_gen.Configuration(
-        subgroup_size=64,
-        workgroup_size=[128, 2, 1],
-        intrinsic=candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(),
-        tile_sizes=[128, 64, 128],
-        subgroup_m_count=2,
-        subgroup_n_count=2,
-        waves_per_eu=4,
-    )
-
-    tf_mlir = candidate_gen.ContractionTuner(
-        "mk", "nk", "mnk"
-    ).apply_params_broadcast_rhs_mmt(problem_size, mlir_template, config)
-
-    modified = tf_mlir.modified
-    embeddable = tf_mlir.embeddable
-
-    assert modified
-    assert (
-        "// transform.named_sequence @match_broadcast_rhs_mmt_Bx4096x640x640("
-        in modified
-    )
-    assert (
-        "intrinsic = #iree_gpu.mma_layout<MFMA_I8_32x32x16_I32>, subgroup_m_count = 2, subgroup_n_count = 2"
-        in modified
-    )
-    assert (
-        "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64"
-        in modified
-    )
-    assert "tile_sizes = [[1, 128, 64, 128]]" in modified
-    assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified
-
-    assert embeddable
-    assert "transform.named_sequence @match_op(" in embeddable
-    assert (
-        "transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate)"
-        in embeddable
-    )
-    assert (
-        "transform.iree.match.cast_compatible_type %lhs = tensor<?x4096x640xi8> : !transform.any_value"
-        in embeddable
-    )
-    assert (
-        "transform.iree.match.cast_compatible_type %rhs = tensor<640x640xi8> : !transform.any_value"
-        in embeddable
-    )
-    assert (
-        "%config = transform.param.constant #iree_codegen.compilation_info<"
-        in embeddable
-    )
-    assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable
-    assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable
-    assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable
-
-
-def test_detect_broadcast_rhs_mmt():
-    mlir_lines = [
-        r"%18 = tensor.empty() : tensor<2x1024x10240xi32>",
-        r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>",
-        r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {',
-    ]
-    assert candidate_gen.ContractionTuner("mk", "nk", "mnk").is_broadcast_rhs_mmt(
-        mlir_lines
-    )
-
-
-def test_parse_mlir():
-    mlir_str = r"""
-        builtin.module {
-            func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
-                %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
-                return %0 : tensor<4xf32>
-            }
-        }
-    """
-    mlir_module = candidate_gen.parse_mlir(mlir_str)
-    assert mlir_module is not None
-    assert isinstance(mlir_module, candidate_gen.ireec._mlir_libs._mlir.ir.Module)
-    assert isinstance(
-        mlir_module.body.operations[0], candidate_gen.ireec.dialects.func.FuncOp
-    )
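To close the loop on `test_parse_mlir`: parsing an MLIR assembly string through IREE's bundled Python bindings needs only a context and `ir.Module.parse`. A minimal sketch of what the `parse_mlir` helper plausibly wraps; the actual implementation lives in `candidate_gen.py` and may differ:

```python
from iree.compiler import ir

# Minimal sketch: parse an MLIR assembly string into an ir.Module.
def parse_mlir(mlir_text: str) -> ir.Module:
    ctx = ir.Context()
    return ir.Module.parse(mlir_text, ctx)

module = parse_mlir("builtin.module {}")
assert module is not None
```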