Skip to content

Commit

Permalink
[Tuner] Fix large model benchmarking (#808)
Browse files Browse the repository at this point in the history
* Add model-specific benchmark timeout
* Fix benchmark argument parsing to allow for `=` in command line
argument values
* Don't print candidate trackers at the very end (too much noise)
* Always promote operands
  • Loading branch information
kuhar authored Jan 10, 2025
1 parent 45322d4 commit b920696
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
11 changes: 5 additions & 6 deletions tuner/examples/simple/simple_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def __init__(self, tuner_context: libtuner.TunerContext):
super().__init__(tuner_context)
self.compile_flags: list[str] = []
self.benchmark_flags: list[str] = []
self.compile_timeout: int = 10
self.compile_timeout: int = 16
self.benchmark_timeout: int = 16

def get_iree_compile_flags(self) -> list[str]:
return self.compile_flags
Expand All @@ -27,7 +28,7 @@ def get_iree_benchmark_module_flags(self) -> list[str]:
return self.benchmark_flags

def get_benchmark_timeout_s(self) -> int:
return 10
return self.benchmark_timeout


def read_flags_file(flags_file: str) -> list[str]:
Expand Down Expand Up @@ -127,7 +128,7 @@ def main():

print("Compiling models with top candidates...")
simple_tuner.compile_flags = compile_flags
simple_tuner.compile_timeout = 60
simple_tuner.compile_timeout = 120
compiled_model_candidates = libtuner.compile(
args,
path_config,
Expand All @@ -141,6 +142,7 @@ def main():

print("Benchmarking compiled model candidates...")
simple_tuner.benchmark_flags = model_benchmark_flags
simple_tuner.benchmark_timeout = 60
top_model_candidates = libtuner.benchmark(
args,
path_config,
Expand All @@ -154,6 +156,3 @@ def main():

print("Check the detailed execution logs in:")
print(path_config.run_log.resolve())

for candidate in candidate_trackers:
libtuner.logging.debug(candidate)
3 changes: 2 additions & 1 deletion tuner/tuner/dispatch_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,11 @@ def generate_compilation_infos(
"reduction": reduction_tile_sizes,
"subgroup_m_count": subgroup_m_count,
"subgroup_n_count": subgroup_n_count,
"promote_operands": [0, 1],
}
if codegen_pipeline == iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse:
lowering_config_args["subgroup"] = subgroup_tile_sizes
lowering_config_args["promote_operands"] = [0, 1]

lowering_config = get_lowering_config(**lowering_config_args)

# Create the TranslationInfoAttr
Expand Down
4 changes: 2 additions & 2 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,10 +497,10 @@ def run_iree_benchmark_module_command(benchmark_pack: BenchmarkPack):
assert flag[:2] == "--", "iree_benchmark_module_flags should begin with '--'"
split_key_value = flag[2:].split("=")
assert (
len(split_key_value) == 2
len(split_key_value) >= 1
), "iree_benchmark_module_flags should have the format --<key>=<value>"
key = split_key_value[0]
value = split_key_value[1]
value = "=".join(split_key_value[1:])
# Allow the tuning client to pass `--function=@func_name`.
if key == "function":
func_name = value
Expand Down

0 comments on commit b920696

Please sign in to comment.