Skip to content

Commit

Permalink
core: Refactor some methods using BlockInsertPoint
Browse files Browse the repository at this point in the history
This factors out quite a few methods in the `Builder`, `Rewriter`,
and `PatternRewriter`.

stack-info: PR: #3704, branch: math-fehr/stack/8
  • Loading branch information
math-fehr committed Jan 20, 2025
1 parent 84b071f commit 3125fc1
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 77 deletions.
11 changes: 0 additions & 11 deletions tests/test_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,22 +618,11 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None:


def test_verify_inline_region():
block = Block()
region = Region(Block())

with pytest.raises(
ValueError, match="Cannot inline region before a block with no parent"
):
Rewriter.inline_region_before(region, block)

with pytest.raises(ValueError, match="Cannot move region into itself."):
Rewriter.inline_region_before(region, region.block)

with pytest.raises(
ValueError, match="Cannot inline region before a block with no parent"
):
Rewriter.inline_region_after(region, block)

with pytest.raises(ValueError, match="Cannot move region into itself."):
Rewriter.inline_region_after(region, region.block)

Expand Down
48 changes: 19 additions & 29 deletions xdsl/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from xdsl.dialects.builtin import ArrayAttr
from xdsl.ir import Attribute, Block, BlockArgument, Operation, OperationInvT, Region
from xdsl.rewriter import InsertPoint, Rewriter
from xdsl.rewriter import BlockInsertPoint, InsertPoint, Rewriter


@dataclass(eq=False)
Expand Down Expand Up @@ -74,36 +74,38 @@ def insert(self, op: OperationInvT) -> OperationInvT:

return op

def create_block_before(
self, insert_before: Block, arg_types: Iterable[Attribute] = ()
def create_block(
self, insert_point: BlockInsertPoint, arg_types: Iterable[Attribute]
) -> Block:
"""
Create a block before `insert_before`, and set
the insertion point at the end of the inserted block.
Create a block at the given location, and set the operation insertion point
at the end of the inserted block.
"""
block = Block(arg_types=arg_types)
Rewriter.insert_block_before(block, insert_before)
Rewriter.insert_block(block, insert_point)

self.insertion_point = InsertPoint.at_end(block)

self.handle_block_creation(block)

return block

def create_block_before(
self, insert_before: Block, arg_types: Iterable[Attribute] = ()
) -> Block:
"""
Create a block before `insert_before`, and set
the insertion point at the end of the inserted block.
"""
return self.create_block(BlockInsertPoint.before(insert_before), arg_types)

def create_block_after(
self, insert_after: Block, arg_types: Iterable[Attribute] = ()
) -> Block:
"""
Create a block after `insert_after`, and set
the insertion point at the end of the inserted block.
"""

block = Block(arg_types=arg_types)
Rewriter.insert_block_after(block, insert_after)
self.insertion_point = InsertPoint.at_end(block)

self.handle_block_creation(block)

return block
return self.create_block(BlockInsertPoint.after(insert_after), arg_types)

def create_block_at_start(
self, region: Region, arg_types: Iterable[Attribute] = ()
Expand All @@ -112,13 +114,7 @@ def create_block_at_start(
Create a block at the start of `region`, and set
the insertion point at the end of the inserted block.
"""
block = Block(arg_types=arg_types)
region.insert_block(block, 0)
self.insertion_point = InsertPoint.at_end(block)

self.handle_block_creation(block)

return block
return self.create_block(BlockInsertPoint.at_start(region), arg_types)

def create_block_at_end(
self, region: Region, arg_types: Iterable[Attribute] = ()
Expand All @@ -127,13 +123,7 @@ def create_block_at_end(
Create a block at the end of `region`, and set
the insertion point at the end of the inserted block.
"""
block = Block(arg_types=arg_types)
region.add_block(block)
self.insertion_point = InsertPoint.at_end(block)

self.handle_block_creation(block)

return block
return self.create_block(BlockInsertPoint.at_end(region), arg_types)

@staticmethod
def _region_no_args(func: Callable[[Builder], None]) -> Region:
Expand Down
19 changes: 10 additions & 9 deletions xdsl/pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
SSAValue,
)
from xdsl.irdl import GenericAttrConstraint, base
from xdsl.rewriter import InsertPoint, Rewriter
from xdsl.rewriter import BlockInsertPoint, InsertPoint, Rewriter
from xdsl.utils.hints import isa
from xdsl.utils.isattr import isattr

Expand Down Expand Up @@ -351,25 +351,26 @@ def move_region_contents_to_new_regions(self, region: Region) -> Region:
self.has_done_action = True
return Rewriter.move_region_contents_to_new_regions(region)

def inline_region(self, region: Region, insertion_point: BlockInsertPoint) -> None:
"""Move the region blocks to the specified insertion point."""
self.has_done_action = True
Rewriter.inline_region(region, insertion_point)

def inline_region_before(self, region: Region, target: Block) -> None:
"""Move the region blocks to an existing region."""
self.has_done_action = True
Rewriter.inline_region_before(region, target)
self.inline_region(region, BlockInsertPoint.before(target))

def inline_region_after(self, region: Region, target: Block) -> None:
"""Move the region blocks to an existing region."""
self.has_done_action = True
Rewriter.inline_region_after(region, target)
self.inline_region(region, BlockInsertPoint.after(target))

def inline_region_at_start(self, region: Region, target: Region) -> None:
"""Move the region blocks to an existing region."""
self.has_done_action = True
Rewriter.inline_region_at_start(region, target)
self.inline_region(region, BlockInsertPoint.at_start(target))

def inline_region_at_end(self, region: Region, target: Region) -> None:
"""Move the region blocks to an existing region."""
self.has_done_action = True
Rewriter.inline_region_at_end(region, target)
self.inline_region(region, BlockInsertPoint.at_end(target))


class RewritePattern(ABC):
Expand Down
56 changes: 28 additions & 28 deletions xdsl/rewriter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from collections.abc import Sequence
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field

from xdsl.ir import Block, Operation, Region, SSAValue
Expand Down Expand Up @@ -227,21 +227,27 @@ def inline_block(
parent_region.detach_block(source)
source.erase()

@staticmethod
def insert_block(block: Block | Iterable[Block], insert_point: BlockInsertPoint):
"""
Insert one or multiple blocks at a given location.
The blocks to insert should be detached from any region.
The insertion point should not be contained in the block to insert.
"""
region = insert_point.region
if insert_point.insert_before is not None:
region.insert_block_before(block, insert_point.insert_before)
else:
region.add_block(block)

@staticmethod
def insert_block_after(block: Block | list[Block], target: Block):
"""
Insert one or multiple blocks after another block.
The blocks to insert should be detached from any region.
The target block should not be contained in the block to insert.
"""
if target.parent is None:
raise Exception("Cannot move a block after a toplevel op")
region = target.parent
block_list = block if isinstance(block, list) else [block]
if len(block_list) == 0:
return
pos = region.get_block_index(target)
region.insert_block(block_list, pos + 1)
Rewriter.insert_block(block, BlockInsertPoint.after(target))

@staticmethod
def insert_block_before(block: Block | list[Block], target: Block):
Expand All @@ -250,12 +256,7 @@ def insert_block_before(block: Block | list[Block], target: Block):
The blocks to insert should be detached from any region.
The target block should not be contained in the block to insert.
"""
if target.parent is None:
raise Exception("Cannot move a block after a toplevel op")
region = target.parent
block_list = block if isinstance(block, list) else [block]
pos = region.get_block_index(target)
region.insert_block(block_list, pos)
Rewriter.insert_block(block, BlockInsertPoint.before(target))

@staticmethod
def insert_op(
Expand All @@ -275,31 +276,30 @@ def move_region_contents_to_new_regions(region: Region) -> Region:
region.move_blocks(new_region)
return new_region

@staticmethod
def inline_region(region: Region, insertion_point: BlockInsertPoint) -> None:
"""Move the region blocks to a given location."""
if insertion_point.insert_before is not None:
region.move_blocks_before(insertion_point.insert_before)
else:
region.move_blocks(insertion_point.region)

@staticmethod
def inline_region_before(region: Region, target: Block) -> None:
"""Move the region blocks to an existing region, before `target`."""
region.move_blocks_before(target)
Rewriter.inline_region(region, BlockInsertPoint.before(target))

@staticmethod
def inline_region_after(region: Region, target: Block) -> None:
"""Move the region blocks to an existing region, after `target`."""
if target.next_block is not None:
Rewriter.inline_region_before(region, target.next_block)
else:
parent_region = target.parent
if parent_region is None:
raise ValueError("Cannot inline region before a block with no parent")
region.move_blocks(region)
Rewriter.inline_region(region, BlockInsertPoint.after(target))

@staticmethod
def inline_region_at_start(region: Region, target: Region) -> None:
"""Move the region blocks to the start of an existing region."""
if target.first_block is not None:
Rewriter.inline_region_before(region, target.first_block)
else:
Rewriter.inline_region_at_end(region, target)
Rewriter.inline_region(region, BlockInsertPoint.at_start(target))

@staticmethod
def inline_region_at_end(region: Region, target: Region) -> None:
"""Move the region blocks to the end of an existing region."""
region.move_blocks(target)
Rewriter.inline_region(region, BlockInsertPoint.at_end(target))

0 comments on commit 3125fc1

Please sign in to comment.