diff --git a/sharktank/sharktank/tools/tuner/candidate_gen.py b/sharktank/sharktank/tools/tuner/candidate_gen.py index c86922205..be31f8914 100755 --- a/sharktank/sharktank/tools/tuner/candidate_gen.py +++ b/sharktank/sharktank/tools/tuner/candidate_gen.py @@ -541,6 +541,8 @@ def supports(self, op_name: str) -> bool: return "matmul_transpose_b" in op_name def get_shapes(self, template: list[str]) -> ProblemSize: + mmt_re = None + dps = None for line in template: if "linalg.generic" not in line: continue @@ -585,8 +587,8 @@ def get_shapes(self, template: list[str]) -> ProblemSize: res_type=res_shaped_type, dispatch_kind=DispatchKind.mmt, ) - - assert False, "Shape not found" + assert mmt_re + assert dps, f"'{mmt_re}' not found in given context" def get_transform_function_mmt( self, problem_size: ProblemSize, functionName: str, configuration: Configuration @@ -1271,13 +1273,13 @@ def walk_mlir_op( def tune( - input: str, - output: str = "", - limit: int = 4096, - num_subgroups: int = 4, - lhs_dims: str = "mk", - rhs_dims: str = "nk", - tile_dims: str = "mnk", + input: str, # Path to the mlir file to be tuned + output: str = "", # Path to the output directory, auto creates one if not given + limit: int = 4096, # Max candidates to be generated + num_subgroups: int = 4, # GPU spec, used to determine candidate generation constraints + lhs_dims: str = "mk", # Dimensions for the left-hand side operand in matrix operations + rhs_dims: str = "nk", # Dimensions for the right-hand side operand in matrix operations + tile_dims: str = "mnk", # Dimensions for the tile size ): input_file = str(input)