Skip to content

Commit

Permalink
core: Remove some constructors from Builder
Browse files Browse the repository at this point in the history
The constructors `after`, `before`, `at_start`, and `at_end` are removed
as users should use the equivalent `InsertPoint` constructors instead.
These constructors are currently preventing the `Rewriter` to inherit
`Builder`.

stack-info: PR: #3702, branch: math-fehr/stack/6
  • Loading branch information
math-fehr committed Jan 6, 2025
1 parent 614e4f8 commit 63748c2
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 56 deletions.
6 changes: 3 additions & 3 deletions docs/Toy/toy/frontend/ir_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass
from typing import NoReturn

from xdsl.builder import Builder
from xdsl.builder import Builder, InsertPoint
from xdsl.dialects.builtin import ModuleOp, TensorType, UnrankedTensorType, f64
from xdsl.ir import Block, Region, SSAValue
from xdsl.utils.scoped_dict import ScopedDict
Expand Down Expand Up @@ -74,7 +74,7 @@ def __init__(self):
# We create an empty MLIR module and codegen functions one at a time and
# add them to the module.
self.module = ModuleOp([])
self.builder = Builder.at_end(self.module.body.blocks[0])
self.builder = Builder(InsertPoint.at_end(self.module.body.blocks[0]))

def ir_gen_module(self, module_ast: ModuleAST) -> ModuleOp:
"""
Expand Down Expand Up @@ -147,7 +147,7 @@ def ir_gen_function(self, function_ast: FunctionAST) -> FuncOp:
block = Block(
arg_types=[UnrankedTensorType(f64) for _ in range(len(proto_args))]
)
self.builder = Builder.at_end(block)
self.builder = Builder(InsertPoint.at_end(block))

# Declare all the function arguments in the symbol table.
for name, value in zip(proto_args, block.args):
Expand Down
7 changes: 3 additions & 4 deletions docs/Toy/toy/rewrites/lower_toy_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from itertools import product
from typing import TypeAlias, TypeVar, cast

from xdsl.builder import Builder
from xdsl.builder import Builder, InsertPoint
from xdsl.context import MLContext
from xdsl.dialects import affine, arith, func, memref, printf
from xdsl.dialects.builtin import (
Expand All @@ -31,7 +31,6 @@
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint

from ..dialects import toy

Expand Down Expand Up @@ -114,7 +113,7 @@ def build_affine_for(
step,
)
builder.insert(op)
body_builder_fn(Builder.at_end(block), induction_var, rest)
body_builder_fn(Builder(InsertPoint.at_end(block)), induction_var, rest)
return op


Expand Down Expand Up @@ -304,7 +303,7 @@ def impl_loop(nested_builder: Builder, ivs: _ValueRange):
store_op = affine.StoreOp(value_to_store, alloc.memref, ivs)
nested_builder.insert(store_op)

builder = Builder.before(op)
builder = Builder(InsertPoint.before(op))
build_affine_loop_nest_const(
builder, lower_bounds, tensor_type.get_shape(), steps, impl_loop
)
Expand Down
21 changes: 5 additions & 16 deletions tests/test_op_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,11 @@ def test_insertion_point_constructors():
)

assert InsertPoint.at_start(target) == InsertPoint(target, op1)
assert Builder.at_start(target).insertion_point == InsertPoint(target, op1)

assert InsertPoint.at_end(target) == InsertPoint(target, None)
assert Builder.at_end(target).insertion_point == InsertPoint(target, None)

assert InsertPoint.before(op1) == InsertPoint(target, op1)
assert Builder.before(op1).insertion_point == InsertPoint(target, op1)

assert InsertPoint.after(op1) == InsertPoint(target, op2)
assert Builder.after(op1).insertion_point == InsertPoint(target, op2)

assert InsertPoint.before(op2) == InsertPoint(target, op2)
assert Builder.before(op2).insertion_point == InsertPoint(target, op2)

assert InsertPoint.after(op2) == InsertPoint(target, None)
assert Builder.after(op2).insertion_point == InsertPoint(target, None)


def test_builder():
Expand All @@ -44,7 +33,7 @@ def test_builder():
)

block = Block()
b = Builder.at_end(block)
b = Builder(InsertPoint.at_end(block))

x = ConstantOp.from_int_and_width(0, 1)
y = ConstantOp.from_int_and_width(1, 1)
Expand All @@ -65,7 +54,7 @@ def test_builder_insertion_point():
)

block = Block()
b = Builder.at_end(block)
b = Builder(InsertPoint.at_end(block))

x = ConstantOp.from_int_and_width(1, 8)
y = ConstantOp.from_int_and_width(2, 8)
Expand All @@ -85,7 +74,7 @@ def test_builder_create_block():
block1 = Block()
block2 = Block()
target = Region([block1, block2])
builder = Builder.at_start(block1)
builder = Builder(InsertPoint.at_start(block1))

new_block1 = builder.create_block_at_start(target, (i32,))
assert len(new_block1.args) == 1
Expand Down Expand Up @@ -120,7 +109,7 @@ def test_builder_create_block():

def test_builder_listener_op_insert():
block = Block()
b = Builder.at_end(block)
b = Builder(InsertPoint.at_end(block))

x = ConstantOp.from_int_and_width(1, 32)
y = ConstantOp.from_int_and_width(2, 32)
Expand All @@ -144,7 +133,7 @@ def add_op_on_insert(op: Operation):
def test_builder_listener_block_created():
block = Block()
region = Region([block])
b = Builder.at_start(block)
b = Builder(InsertPoint.at_start(block))

created_blocks: list[Block] = []

Expand Down
6 changes: 3 additions & 3 deletions xdsl/backend/riscv/prologue_epilogue_insertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from ordered_set import OrderedSet

from xdsl.builder import Builder
from xdsl.builder import Builder, InsertPoint
from xdsl.context import MLContext
from xdsl.dialects import builtin, riscv, riscv_func
from xdsl.dialects.riscv import (
Expand Down Expand Up @@ -51,7 +51,7 @@ def get_register_size(r: RISCVRegisterType):
return self.flen

# Build the prologue at the beginning of the function.
builder = Builder.at_start(func.body.blocks[0])
builder = Builder(InsertPoint.at_start(func.body.blocks[0]))
sp_register = builder.insert(riscv.GetRegisterOp(Registers.SP))
stack_size = sum(get_register_size(r) for r in used_callee_preserved_registers)
builder.insert(riscv.AddiOp(sp_register, -stack_size, rd=Registers.SP))
Expand All @@ -73,7 +73,7 @@ def get_register_size(r: RISCVRegisterType):
if not isinstance(ret_op, riscv_func.ReturnOp):
continue

builder = Builder.before(ret_op)
builder = Builder(InsertPoint.before(ret_op))
offset = 0
for reg in used_callee_preserved_registers:
if isinstance(reg, IntRegisterType):
Expand Down
26 changes: 3 additions & 23 deletions xdsl/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,26 +54,6 @@ class Builder(BuilderListener):
insertion_point: InsertPoint
"""Operations will be inserted at this location."""

@staticmethod
def before(op: Operation) -> Builder:
"""Creates a builder with the insertion point before an operation."""
return Builder(InsertPoint.before(op))

@staticmethod
def after(op: Operation) -> Builder:
"""Creates a builder with the insertion point after an operation."""
return Builder(InsertPoint.after(op))

@staticmethod
def at_start(block: Block) -> Builder:
"""Creates a builder with the insertion point at the start of a block."""
return Builder(InsertPoint.at_start(block))

@staticmethod
def at_end(block: Block) -> Builder:
"""Creates a builder with the insertion point at the end of a block."""
return Builder(InsertPoint.at_end(block))

def insert(self, op: OperationInvT) -> OperationInvT:
"""Inserts `op` at the current insertion point."""

Expand Down Expand Up @@ -161,7 +141,7 @@ def _region_no_args(func: Callable[[Builder], None]) -> Region:
Generates a single-block region.
"""
block = Block()
builder = Builder.at_end(block)
builder = Builder(InsertPoint.at_end(block))
func(builder)
return Region(block)

Expand All @@ -179,7 +159,7 @@ def _region_args(

def wrapper(func: _CallableRegionFuncType) -> Region:
block = Block(arg_types=input_types)
builder = Builder.at_start(block)
builder = Builder(InsertPoint.at_start(block))

func(builder, block.args)

Expand Down Expand Up @@ -372,7 +352,7 @@ def __init__(self, arg: Builder | Block | Region | None):
if isinstance(arg, Region):
arg = arg.block
if isinstance(arg, Block):
arg = Builder.at_end(arg)
arg = Builder(InsertPoint.at_end(arg))
self._builder = arg

def __enter__(self) -> tuple[BlockArgument, ...]:
Expand Down
8 changes: 4 additions & 4 deletions xdsl/dialects/irdl/pyrdl_to_irdl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from xdsl.builder import Builder
from xdsl.builder import Builder, InsertPoint
from xdsl.dialects.irdl import AnyOp
from xdsl.dialects.irdl.irdl import (
AttributeOp,
Expand Down Expand Up @@ -37,7 +37,7 @@ def op_def_to_irdl(op: type[IRDLOperation]) -> OperationOp:
op_def = op.get_irdl_definition()

block = Block()
builder = Builder.at_end(block)
builder = Builder(InsertPoint.at_end(block))

# Operands
operand_values: list[SSAValue] = []
Expand All @@ -63,7 +63,7 @@ def attr_def_to_irdl(
attr_def = attr.get_irdl_definition()

block = Block()
builder = Builder.at_end(block)
builder = Builder(InsertPoint.at_end(block))

# Parameters
param_values: list[SSAValue] = []
Expand All @@ -77,7 +77,7 @@ def attr_def_to_irdl(
def dialect_to_irdl(dialect: Dialect, name: str) -> DialectOp:
"""Convert a dialect definition to an IRDL dialect definition."""
block = Block()
builder = Builder.at_end(block)
builder = Builder(InsertPoint.at_end(block))

for attribute in dialect.attributes:
if not issubclass(attribute, ParametrizedAttribute):
Expand Down
6 changes: 3 additions & 3 deletions xdsl/frontend/jaxpr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from jax._src.core import ClosedJaxpr

from xdsl.builder import Builder
from xdsl.builder import Builder, InsertPoint
from xdsl.dialects.builtin import FunctionType, ModuleOp, TensorType, f32
from xdsl.dialects.func import FuncOp, ReturnOp
from xdsl.ir import Block, Region
Expand Down Expand Up @@ -36,7 +36,7 @@ def __init__(self):
# We create an empty MLIR module and codegen functions one at a time and
# add them to the module.
self.module = ModuleOp([])
self.builder = Builder.at_end(self.module.body.blocks[0])
self.builder = Builder(InsertPoint.at_end(self.module.body.blocks[0]))

def ir_gen_module(self, jaxpr: ClosedJaxpr) -> ModuleOp:
"""
Expand All @@ -62,7 +62,7 @@ def ir_gen_module(self, jaxpr: ClosedJaxpr) -> ModuleOp:
]

block = Block(arg_types=input_types)
self.builder = Builder.at_end(block)
self.builder = Builder(InsertPoint.at_end(block))

func_type = FunctionType.from_lists(input_types, output_types)

Expand Down

0 comments on commit 63748c2

Please sign in to comment.