diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 51c2d5114..e7a83a507 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -290,7 +290,7 @@ def strip_compilation_info(input_path: Path) -> str: return result.process_res.stdout -def main(): +def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("input", help="Input mlir file", type=str) parser.add_argument( @@ -349,7 +349,7 @@ def main(): prefetch_shared_memory=args.prefetch_shared_memory_options, no_reduce_shared_memory_bank_conflicts=args.no_reduce_shared_memory_bank_conflicts_options, ) - specs = generate_configs_and_td_specs( + specs: list[ir.Module] = generate_configs_and_td_specs( mlir_module, tuner_ctx, args.limit, @@ -363,9 +363,9 @@ def main(): spec_path = spec_dir / f"{candidate_num}_spec.mlir" spec_dir.mkdir(parents=True, exist_ok=True) with open(spec_path, "w") as f: - local_scope_spec_str = spec.operation.get_asm(use_local_scope=True) + local_scope_spec_str: str = spec.operation.get_asm(use_local_scope=True) f.write(local_scope_spec_str) if __name__ == "__main__": - args = main() + main()