From 63748c2520f967131aed61a661a495b108a66846 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Sun, 5 Jan 2025 11:32:37 +0000 Subject: [PATCH] core: Remove some constructors from `Builder` The constructors `after`, `before`, `at_start`, and `at_end` are removed as users should use the equivalent `InsertPoint` constructors instead. These constructors are currently preventing the `Rewriter` to inherit `Builder`. stack-info: PR: https://github.com/xdslproject/xdsl/pull/3702, branch: math-fehr/stack/6 --- docs/Toy/toy/frontend/ir_gen.py | 6 ++--- docs/Toy/toy/rewrites/lower_toy_affine.py | 7 +++-- tests/test_op_builder.py | 21 ++++----------- .../riscv/prologue_epilogue_insertion.py | 6 ++--- xdsl/builder.py | 26 +++---------------- xdsl/dialects/irdl/pyrdl_to_irdl.py | 8 +++--- xdsl/frontend/jaxpr/__init__.py | 6 ++--- 7 files changed, 24 insertions(+), 56 deletions(-) diff --git a/docs/Toy/toy/frontend/ir_gen.py b/docs/Toy/toy/frontend/ir_gen.py index 2697f65198..1e6d1048f5 100644 --- a/docs/Toy/toy/frontend/ir_gen.py +++ b/docs/Toy/toy/frontend/ir_gen.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import NoReturn -from xdsl.builder import Builder +from xdsl.builder import Builder, InsertPoint from xdsl.dialects.builtin import ModuleOp, TensorType, UnrankedTensorType, f64 from xdsl.ir import Block, Region, SSAValue from xdsl.utils.scoped_dict import ScopedDict @@ -74,7 +74,7 @@ def __init__(self): # We create an empty MLIR module and codegen functions one at a time and # add them to the module. self.module = ModuleOp([]) - self.builder = Builder.at_end(self.module.body.blocks[0]) + self.builder = Builder(InsertPoint.at_end(self.module.body.blocks[0])) def ir_gen_module(self, module_ast: ModuleAST) -> ModuleOp: """ @@ -147,7 +147,7 @@ def ir_gen_function(self, function_ast: FunctionAST) -> FuncOp: block = Block( arg_types=[UnrankedTensorType(f64) for _ in range(len(proto_args))] ) - self.builder = Builder.at_end(block) + self.builder = Builder(InsertPoint.at_end(block)) # Declare all the function arguments in the symbol table. for name, value in zip(proto_args, block.args): diff --git a/docs/Toy/toy/rewrites/lower_toy_affine.py b/docs/Toy/toy/rewrites/lower_toy_affine.py index c9b4a1afb5..f5adddabd3 100644 --- a/docs/Toy/toy/rewrites/lower_toy_affine.py +++ b/docs/Toy/toy/rewrites/lower_toy_affine.py @@ -8,7 +8,7 @@ from itertools import product from typing import TypeAlias, TypeVar, cast -from xdsl.builder import Builder +from xdsl.builder import Builder, InsertPoint from xdsl.context import MLContext from xdsl.dialects import affine, arith, func, memref, printf from xdsl.dialects.builtin import ( @@ -31,7 +31,6 @@ RewritePattern, op_type_rewrite_pattern, ) -from xdsl.rewriter import InsertPoint from ..dialects import toy @@ -114,7 +113,7 @@ def build_affine_for( step, ) builder.insert(op) - body_builder_fn(Builder.at_end(block), induction_var, rest) + body_builder_fn(Builder(InsertPoint.at_end(block)), induction_var, rest) return op @@ -304,7 +303,7 @@ def impl_loop(nested_builder: Builder, ivs: _ValueRange): store_op = affine.StoreOp(value_to_store, alloc.memref, ivs) nested_builder.insert(store_op) - builder = Builder.before(op) + builder = Builder(InsertPoint.before(op)) build_affine_loop_nest_const( builder, lower_bounds, tensor_type.get_shape(), steps, impl_loop ) diff --git a/tests/test_op_builder.py b/tests/test_op_builder.py index 1fbba7d12c..e0e3e316ea 100644 --- a/tests/test_op_builder.py +++ b/tests/test_op_builder.py @@ -17,22 +17,11 @@ def test_insertion_point_constructors(): ) assert InsertPoint.at_start(target) == InsertPoint(target, op1) - assert Builder.at_start(target).insertion_point == InsertPoint(target, op1) - assert InsertPoint.at_end(target) == InsertPoint(target, None) - assert Builder.at_end(target).insertion_point == InsertPoint(target, None) - assert InsertPoint.before(op1) == InsertPoint(target, op1) - assert Builder.before(op1).insertion_point == InsertPoint(target, op1) - assert InsertPoint.after(op1) == InsertPoint(target, op2) - assert Builder.after(op1).insertion_point == InsertPoint(target, op2) - assert InsertPoint.before(op2) == InsertPoint(target, op2) - assert Builder.before(op2).insertion_point == InsertPoint(target, op2) - assert InsertPoint.after(op2) == InsertPoint(target, None) - assert Builder.after(op2).insertion_point == InsertPoint(target, None) def test_builder(): @@ -44,7 +33,7 @@ def test_builder(): ) block = Block() - b = Builder.at_end(block) + b = Builder(InsertPoint.at_end(block)) x = ConstantOp.from_int_and_width(0, 1) y = ConstantOp.from_int_and_width(1, 1) @@ -65,7 +54,7 @@ def test_builder_insertion_point(): ) block = Block() - b = Builder.at_end(block) + b = Builder(InsertPoint.at_end(block)) x = ConstantOp.from_int_and_width(1, 8) y = ConstantOp.from_int_and_width(2, 8) @@ -85,7 +74,7 @@ def test_builder_create_block(): block1 = Block() block2 = Block() target = Region([block1, block2]) - builder = Builder.at_start(block1) + builder = Builder(InsertPoint.at_start(block1)) new_block1 = builder.create_block_at_start(target, (i32,)) assert len(new_block1.args) == 1 @@ -120,7 +109,7 @@ def test_builder_create_block(): def test_builder_listener_op_insert(): block = Block() - b = Builder.at_end(block) + b = Builder(InsertPoint.at_end(block)) x = ConstantOp.from_int_and_width(1, 32) y = ConstantOp.from_int_and_width(2, 32) @@ -144,7 +133,7 @@ def add_op_on_insert(op: Operation): def test_builder_listener_block_created(): block = Block() region = Region([block]) - b = Builder.at_start(block) + b = Builder(InsertPoint.at_start(block)) created_blocks: list[Block] = [] diff --git a/xdsl/backend/riscv/prologue_epilogue_insertion.py b/xdsl/backend/riscv/prologue_epilogue_insertion.py index e1bf5fb227..85fa00594a 100644 --- a/xdsl/backend/riscv/prologue_epilogue_insertion.py +++ b/xdsl/backend/riscv/prologue_epilogue_insertion.py @@ -2,7 +2,7 @@ from ordered_set import OrderedSet -from xdsl.builder import Builder +from xdsl.builder import Builder, InsertPoint from xdsl.context import MLContext from xdsl.dialects import builtin, riscv, riscv_func from xdsl.dialects.riscv import ( @@ -51,7 +51,7 @@ def get_register_size(r: RISCVRegisterType): return self.flen # Build the prologue at the beginning of the function. - builder = Builder.at_start(func.body.blocks[0]) + builder = Builder(InsertPoint.at_start(func.body.blocks[0])) sp_register = builder.insert(riscv.GetRegisterOp(Registers.SP)) stack_size = sum(get_register_size(r) for r in used_callee_preserved_registers) builder.insert(riscv.AddiOp(sp_register, -stack_size, rd=Registers.SP)) @@ -73,7 +73,7 @@ def get_register_size(r: RISCVRegisterType): if not isinstance(ret_op, riscv_func.ReturnOp): continue - builder = Builder.before(ret_op) + builder = Builder(InsertPoint.before(ret_op)) offset = 0 for reg in used_callee_preserved_registers: if isinstance(reg, IntRegisterType): diff --git a/xdsl/builder.py b/xdsl/builder.py index e995ea0082..6e4c64ea75 100644 --- a/xdsl/builder.py +++ b/xdsl/builder.py @@ -54,26 +54,6 @@ class Builder(BuilderListener): insertion_point: InsertPoint """Operations will be inserted at this location.""" - @staticmethod - def before(op: Operation) -> Builder: - """Creates a builder with the insertion point before an operation.""" - return Builder(InsertPoint.before(op)) - - @staticmethod - def after(op: Operation) -> Builder: - """Creates a builder with the insertion point after an operation.""" - return Builder(InsertPoint.after(op)) - - @staticmethod - def at_start(block: Block) -> Builder: - """Creates a builder with the insertion point at the start of a block.""" - return Builder(InsertPoint.at_start(block)) - - @staticmethod - def at_end(block: Block) -> Builder: - """Creates a builder with the insertion point at the end of a block.""" - return Builder(InsertPoint.at_end(block)) - def insert(self, op: OperationInvT) -> OperationInvT: """Inserts `op` at the current insertion point.""" @@ -161,7 +141,7 @@ def _region_no_args(func: Callable[[Builder], None]) -> Region: Generates a single-block region. """ block = Block() - builder = Builder.at_end(block) + builder = Builder(InsertPoint.at_end(block)) func(builder) return Region(block) @@ -179,7 +159,7 @@ def _region_args( def wrapper(func: _CallableRegionFuncType) -> Region: block = Block(arg_types=input_types) - builder = Builder.at_start(block) + builder = Builder(InsertPoint.at_start(block)) func(builder, block.args) @@ -372,7 +352,7 @@ def __init__(self, arg: Builder | Block | Region | None): if isinstance(arg, Region): arg = arg.block if isinstance(arg, Block): - arg = Builder.at_end(arg) + arg = Builder(InsertPoint.at_end(arg)) self._builder = arg def __enter__(self) -> tuple[BlockArgument, ...]: diff --git a/xdsl/dialects/irdl/pyrdl_to_irdl.py b/xdsl/dialects/irdl/pyrdl_to_irdl.py index ea5ddfabfa..4f06b86caf 100644 --- a/xdsl/dialects/irdl/pyrdl_to_irdl.py +++ b/xdsl/dialects/irdl/pyrdl_to_irdl.py @@ -1,4 +1,4 @@ -from xdsl.builder import Builder +from xdsl.builder import Builder, InsertPoint from xdsl.dialects.irdl import AnyOp from xdsl.dialects.irdl.irdl import ( AttributeOp, @@ -37,7 +37,7 @@ def op_def_to_irdl(op: type[IRDLOperation]) -> OperationOp: op_def = op.get_irdl_definition() block = Block() - builder = Builder.at_end(block) + builder = Builder(InsertPoint.at_end(block)) # Operands operand_values: list[SSAValue] = [] @@ -63,7 +63,7 @@ def attr_def_to_irdl( attr_def = attr.get_irdl_definition() block = Block() - builder = Builder.at_end(block) + builder = Builder(InsertPoint.at_end(block)) # Parameters param_values: list[SSAValue] = [] @@ -77,7 +77,7 @@ def attr_def_to_irdl( def dialect_to_irdl(dialect: Dialect, name: str) -> DialectOp: """Convert a dialect definition to an IRDL dialect definition.""" block = Block() - builder = Builder.at_end(block) + builder = Builder(InsertPoint.at_end(block)) for attribute in dialect.attributes: if not issubclass(attribute, ParametrizedAttribute): diff --git a/xdsl/frontend/jaxpr/__init__.py b/xdsl/frontend/jaxpr/__init__.py index a85e36a60f..510db149f1 100644 --- a/xdsl/frontend/jaxpr/__init__.py +++ b/xdsl/frontend/jaxpr/__init__.py @@ -2,7 +2,7 @@ from jax._src.core import ClosedJaxpr -from xdsl.builder import Builder +from xdsl.builder import Builder, InsertPoint from xdsl.dialects.builtin import FunctionType, ModuleOp, TensorType, f32 from xdsl.dialects.func import FuncOp, ReturnOp from xdsl.ir import Block, Region @@ -36,7 +36,7 @@ def __init__(self): # We create an empty MLIR module and codegen functions one at a time and # add them to the module. self.module = ModuleOp([]) - self.builder = Builder.at_end(self.module.body.blocks[0]) + self.builder = Builder(InsertPoint.at_end(self.module.body.blocks[0])) def ir_gen_module(self, jaxpr: ClosedJaxpr) -> ModuleOp: """ @@ -62,7 +62,7 @@ def ir_gen_module(self, jaxpr: ClosedJaxpr) -> ModuleOp: ] block = Block(arg_types=input_types) - self.builder = Builder.at_end(block) + self.builder = Builder(InsertPoint.at_end(block)) func_type = FunctionType.from_lists(input_types, output_types)