diff --git a/tests/test_rewriter.py b/tests/test_rewriter.py index 38cca8a129..bdb668fd62 100644 --- a/tests/test_rewriter.py +++ b/tests/test_rewriter.py @@ -618,22 +618,11 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: def test_verify_inline_region(): - block = Block() region = Region(Block()) - with pytest.raises( - ValueError, match="Cannot inline region before a block with no parent" - ): - Rewriter.inline_region_before(region, block) - with pytest.raises(ValueError, match="Cannot move region into itself."): Rewriter.inline_region_before(region, region.block) - with pytest.raises( - ValueError, match="Cannot inline region before a block with no parent" - ): - Rewriter.inline_region_after(region, block) - with pytest.raises(ValueError, match="Cannot move region into itself."): Rewriter.inline_region_after(region, region.block) diff --git a/xdsl/builder.py b/xdsl/builder.py index 50ce855d6e..c9b9e7374d 100644 --- a/xdsl/builder.py +++ b/xdsl/builder.py @@ -9,7 +9,7 @@ from xdsl.dialects.builtin import ArrayAttr from xdsl.ir import Attribute, Block, BlockArgument, Operation, OperationInvT, Region -from xdsl.rewriter import InsertPoint, Rewriter +from xdsl.rewriter import BlockInsertPoint, InsertPoint, Rewriter @dataclass(eq=False) @@ -74,21 +74,30 @@ def insert(self, op: OperationInvT) -> OperationInvT: return op - def create_block_before( - self, insert_before: Block, arg_types: Iterable[Attribute] = () + def create_block( + self, insert_point: BlockInsertPoint, arg_types: Iterable[Attribute] ) -> Block: """ - Create a block before `insert_before`, and set - the insertion point at the end of the inserted block. + Create a block at the given location, and set the operation insertion point + at the end of the inserted block. """ block = Block(arg_types=arg_types) - Rewriter.insert_block_before(block, insert_before) + Rewriter.insert_block(block, insert_point) + self.insertion_point = InsertPoint.at_end(block) self.handle_block_creation(block) - return block + def create_block_before( + self, insert_before: Block, arg_types: Iterable[Attribute] = () + ) -> Block: + """ + Create a block before `insert_before`, and set + the insertion point at the end of the inserted block. + """ + return self.create_block(BlockInsertPoint.before(insert_before), arg_types) + def create_block_after( self, insert_after: Block, arg_types: Iterable[Attribute] = () ) -> Block: @@ -96,14 +105,7 @@ def create_block_after( Create a block after `insert_after`, and set the insertion point at the end of the inserted block. """ - - block = Block(arg_types=arg_types) - Rewriter.insert_block_after(block, insert_after) - self.insertion_point = InsertPoint.at_end(block) - - self.handle_block_creation(block) - - return block + return self.create_block(BlockInsertPoint.after(insert_after), arg_types) def create_block_at_start( self, region: Region, arg_types: Iterable[Attribute] = () @@ -112,13 +114,7 @@ def create_block_at_start( Create a block at the start of `region`, and set the insertion point at the end of the inserted block. """ - block = Block(arg_types=arg_types) - region.insert_block(block, 0) - self.insertion_point = InsertPoint.at_end(block) - - self.handle_block_creation(block) - - return block + return self.create_block(BlockInsertPoint.at_start(region), arg_types) def create_block_at_end( self, region: Region, arg_types: Iterable[Attribute] = () @@ -127,13 +123,7 @@ def create_block_at_end( Create a block at the end of `region`, and set the insertion point at the end of the inserted block. """ - block = Block(arg_types=arg_types) - region.add_block(block) - self.insertion_point = InsertPoint.at_end(block) - - self.handle_block_creation(block) - - return block + return self.create_block(BlockInsertPoint.at_end(region), arg_types) @staticmethod def _region_no_args(func: Callable[[Builder], None]) -> Region: diff --git a/xdsl/pattern_rewriter.py b/xdsl/pattern_rewriter.py index ec10ff20c5..340ad51299 100644 --- a/xdsl/pattern_rewriter.py +++ b/xdsl/pattern_rewriter.py @@ -23,7 +23,7 @@ SSAValue, ) from xdsl.irdl import GenericAttrConstraint, base -from xdsl.rewriter import InsertPoint, Rewriter +from xdsl.rewriter import BlockInsertPoint, InsertPoint, Rewriter from xdsl.utils.hints import isa from xdsl.utils.isattr import isattr @@ -351,25 +351,26 @@ def move_region_contents_to_new_regions(self, region: Region) -> Region: self.has_done_action = True return Rewriter.move_region_contents_to_new_regions(region) + def inline_region(self, region: Region, insertion_point: BlockInsertPoint) -> None: + """Move the region blocks to the specified insertion point.""" + self.has_done_action = True + Rewriter.inline_region(region, insertion_point) + def inline_region_before(self, region: Region, target: Block) -> None: """Move the region blocks to an existing region.""" - self.has_done_action = True - Rewriter.inline_region_before(region, target) + self.inline_region(region, BlockInsertPoint.before(target)) def inline_region_after(self, region: Region, target: Block) -> None: """Move the region blocks to an existing region.""" - self.has_done_action = True - Rewriter.inline_region_after(region, target) + self.inline_region(region, BlockInsertPoint.after(target)) def inline_region_at_start(self, region: Region, target: Region) -> None: """Move the region blocks to an existing region.""" - self.has_done_action = True - Rewriter.inline_region_at_start(region, target) + self.inline_region(region, BlockInsertPoint.at_start(target)) def inline_region_at_end(self, region: Region, target: Region) -> None: """Move the region blocks to an existing region.""" - self.has_done_action = True - Rewriter.inline_region_at_end(region, target) + self.inline_region(region, BlockInsertPoint.at_end(target)) class RewritePattern(ABC): diff --git a/xdsl/rewriter.py b/xdsl/rewriter.py index 5154d2d934..a45cc5161a 100644 --- a/xdsl/rewriter.py +++ b/xdsl/rewriter.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from dataclasses import dataclass, field from xdsl.ir import Block, Operation, Region, SSAValue @@ -227,6 +227,19 @@ def inline_block( parent_region.detach_block(source) source.erase() + @staticmethod + def insert_block(block: Block | Iterable[Block], insert_point: BlockInsertPoint): + """ + Insert one or multiple blocks at a given location. + The blocks to insert should be detached from any region. + The insertion point should not be contained in the block to insert. + """ + region = insert_point.region + if insert_point.insert_before is not None: + region.insert_block_before(block, insert_point.insert_before) + else: + region.add_block(block) + @staticmethod def insert_block_after(block: Block | list[Block], target: Block): """ @@ -234,14 +247,7 @@ def insert_block_after(block: Block | list[Block], target: Block): The blocks to insert should be detached from any region. The target block should not be contained in the block to insert. """ - if target.parent is None: - raise Exception("Cannot move a block after a toplevel op") - region = target.parent - block_list = block if isinstance(block, list) else [block] - if len(block_list) == 0: - return - pos = region.get_block_index(target) - region.insert_block(block_list, pos + 1) + Rewriter.insert_block(block, BlockInsertPoint.after(target)) @staticmethod def insert_block_before(block: Block | list[Block], target: Block): @@ -250,12 +256,7 @@ def insert_block_before(block: Block | list[Block], target: Block): The blocks to insert should be detached from any region. The target block should not be contained in the block to insert. """ - if target.parent is None: - raise Exception("Cannot move a block after a toplevel op") - region = target.parent - block_list = block if isinstance(block, list) else [block] - pos = region.get_block_index(target) - region.insert_block(block_list, pos) + Rewriter.insert_block(block, BlockInsertPoint.before(target)) @staticmethod def insert_op( @@ -275,31 +276,30 @@ def move_region_contents_to_new_regions(region: Region) -> Region: region.move_blocks(new_region) return new_region + @staticmethod + def inline_region(region: Region, insertion_point: BlockInsertPoint) -> None: + """Move the region blocks to a given location.""" + if insertion_point.insert_before is not None: + region.move_blocks_before(insertion_point.insert_before) + else: + region.move_blocks(insertion_point.region) + @staticmethod def inline_region_before(region: Region, target: Block) -> None: """Move the region blocks to an existing region, before `target`.""" - region.move_blocks_before(target) + Rewriter.inline_region(region, BlockInsertPoint.before(target)) @staticmethod def inline_region_after(region: Region, target: Block) -> None: """Move the region blocks to an existing region, after `target`.""" - if target.next_block is not None: - Rewriter.inline_region_before(region, target.next_block) - else: - parent_region = target.parent - if parent_region is None: - raise ValueError("Cannot inline region before a block with no parent") - region.move_blocks(region) + Rewriter.inline_region(region, BlockInsertPoint.after(target)) @staticmethod def inline_region_at_start(region: Region, target: Region) -> None: """Move the region blocks to the start of an existing region.""" - if target.first_block is not None: - Rewriter.inline_region_before(region, target.first_block) - else: - Rewriter.inline_region_at_end(region, target) + Rewriter.inline_region(region, BlockInsertPoint.at_start(target)) @staticmethod def inline_region_at_end(region: Region, target: Region) -> None: """Move the region blocks to the end of an existing region.""" - region.move_blocks(target) + Rewriter.inline_region(region, BlockInsertPoint.at_end(target))