Skip to content

Commit

Permalink
Pass dialect_stack into get_optional_op
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarice committed Dec 23, 2024
1 parent fa064ee commit f4c6c4d
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 27 deletions.
34 changes: 34 additions & 0 deletions tests/test_mlcontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,40 @@ def test_get_op_unregistered():
assert issubclass(ctx.get_op("test.dummy2"), UnregisteredOp)


def test_get_op_with_dialect_stack():
"""Test `get_op` and `get_optional_op` methods."""
ctx = MLContext()
ctx.load_op(DummyOp)

assert ctx.get_op("dummy", dialect_stack=("test",)) == DummyOp
with pytest.raises(Exception):
_ = ctx.get_op("dummy2", dialect_stack=("test",))

assert ctx.get_optional_op("dummy", dialect_stack=("test",)) == DummyOp
assert ctx.get_optional_op("dummy2", dialect_stack=("test",)) is None


def test_get_op_unregistered_with_dialect_stack():
"""
Test `get_op` and `get_optional_op`
methods with the `allow_unregistered` flag.
"""
ctx = MLContext(allow_unregistered=True)
ctx.load_op(DummyOp)

assert ctx.get_optional_op("dummy", dialect_stack=("test",)) == DummyOp
op_type = ctx.get_optional_op("dummy2", dialect_stack=("test",))
print(op_type)
assert op_type is not None
assert issubclass(op_type, UnregisteredOp)
assert op_type.create().op_name.data == "dummy2"

assert ctx.get_op("dummy", dialect_stack=("test",)) == DummyOp
op_type = ctx.get_op("dummy2", dialect_stack=("test",))
assert issubclass(op_type, UnregisteredOp)
assert op_type.create().op_name.data == "dummy2"


def test_get_attr():
"""Test `get_attr` and `get_optional_attr` methods."""
ctx = MLContext()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ def test_missing_custom_format():
ctx = MLContext()
ctx.load_dialect(Arith)
ctx.load_dialect(Builtin)
ctx.load_op(PlusCustomFormatOp)
ctx.load_op(NoCustomFormatOp)

parser = Parser(ctx, prog)
with pytest.raises(ParseError):
Expand Down
46 changes: 29 additions & 17 deletions xdsl/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Callable, Iterable
from collections.abc import Callable, Iterable, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -112,43 +112,55 @@ def load_attr(self, attr: "type[Attribute]") -> None:
raise Exception(f"Attribute {attr.name} has already been loaded")
self._loaded_attrs[attr.name] = attr

def get_optional_op(self, name: str) -> "type[Operation] | None":
"""
Get an operation class from its name if it exists.
If the operation is not registered, return None unless unregistered operations
are allowed in the context, in which case return an UnregisteredOp.
"""
# If the operation is already loaded, returns it.
def _get_known_op(self, name: str) -> "type[Operation] | None":
if name in self._loaded_ops:
return self._loaded_ops[name]

# Otherwise, check if the operation dialect is registered.
if "." in name:
dialect_name, _ = Dialect.split_name(name)
if (
dialect_name in self._registered_dialects
and dialect_name not in self._loaded_dialects
):
self.load_registered_dialect(dialect_name)
return self.get_optional_op(name)
return self._get_known_op(name)

# If the dialect is unregistered, but the context allows unregistered
# operations, return an UnregisteredOp.
def get_optional_op(
self, name: str, *, dialect_stack: Sequence[str] = ()
) -> "type[Operation] | None":
"""
Get an operation class from its name if it exists or is contained in one of the
dialects in the dialect stack.
If the operation is not registered, return None unless unregistered operations
are allowed in the context, in which case return an UnregisteredOp.
"""
# Check if the name is known.
if op_type := self._get_known_op(name):
return op_type

# Check appending each dialect in the dialect stack.
for dialect_name in reversed(dialect_stack):
dialect_and_name = f"{dialect_name}.{name}"
if op_type := self._get_known_op(dialect_and_name):
return op_type

# If the context allows unregistered operations then create an UnregisteredOp
if self.allow_unregistered:
from xdsl.dialects.builtin import UnregisteredOp

op_type = UnregisteredOp.with_name(name)
self._loaded_ops[name] = op_type
return op_type
return None

def get_op(self, name: str) -> "type[Operation]":
def get_op(
self, name: str, *, dialect_stack: Sequence[str] = ()
) -> "type[Operation]":
"""
Get an operation class from its name.
Get an operation class from its name if it exists or is contained in one of the
dialects in the dialect stack.
If the operation is not registered, raise an exception unless unregistered
operations are allowed in the context, in which case return an UnregisteredOp.
"""
if op_type := self.get_optional_op(name):
if op_type := self.get_optional_op(name, dialect_stack=dialect_stack):
return op_type
raise Exception(f"Operation {name} is not registered")

Expand Down
13 changes: 4 additions & 9 deletions xdsl/parser/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,16 +744,11 @@ def _get_op_by_name(self, name: str) -> type[Operation]:
Raises an error if the operation is not registered, and if unregistered
dialects are not allowed.
"""
op_type = self.ctx.get_optional_op(name)
if op_type is not None:
if op_type := self.ctx.get_optional_op(
name, dialect_stack=self._parser_state.dialect_stack
):
return op_type

for dialect_name in reversed(self._parser_state.dialect_stack):
op_type = self.ctx.get_optional_op(f"{dialect_name}.{name}")
if op_type is not None:
return op_type

self.raise_error(f"unregistered operation {name}!")
self.raise_error(f"Operation {name} is not registered")

def _parse_op_result(self) -> tuple[Span, int]:
"""
Expand Down

0 comments on commit f4c6c4d

Please sign in to comment.