Skip to content

Commit

Permalink
interpreter: support a tuple of listeners in Interpreter (#3695)
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh authored Jan 6, 2025
1 parent b83ddb2 commit 2ce4059
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 6 deletions.
4 changes: 2 additions & 2 deletions docs/marimo/linalg_snitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def _(TypedPtr, a_shape, b_shape, c_shape, ctx, mo, riscv_module):
riscv_c_shaped = ShapedArray(TypedPtr.new_float64([0.0] * c_len), c_shape)

riscv_op_counter = OpCounter()
riscv_interpreter = Interpreter(riscv_module, listener=riscv_op_counter)
riscv_interpreter = Interpreter(riscv_module, listeners=(riscv_op_counter,))

register_implementations(riscv_interpreter, ctx, include_wgpu=False, include_onnx=False)

Expand Down Expand Up @@ -549,7 +549,7 @@ def _(
):
snitch_op_counter = OpCounter()
snitch_interpreter = Interpreter(
snitch_stream_module, listener=snitch_op_counter
snitch_stream_module, listeners=(snitch_op_counter,)
)

snitch_c_shaped = ShapedArray(TypedPtr.new_float64([0.0] * c_len), c_shape)
Expand Down
2 changes: 1 addition & 1 deletion tests/interpreters/test_scf_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def false_region():

def test_tracer():
tracer = OpCounter()
interpreter = Interpreter(sum_to_for_op.clone(), listener=tracer)
interpreter = Interpreter(sum_to_for_op.clone(), listeners=(tracer,))
interpreter.register_implementations(ScfFunctions())
interpreter.register_implementations(FuncFunctions())
interpreter.register_implementations(ArithFunctions())
Expand Down
31 changes: 31 additions & 0 deletions tests/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,34 @@ def index_value(

assert i.value_for_attribute(IntegerAttr(1, i32), i32) == 1
assert i.value_for_attribute(IntegerAttr(1, i32), index) == 1


def test_combined_listener():
@dataclass
class DemoListener(Interpreter.Listener):
strings: list[str]
key: str

def will_interpret_op(self, op: Operation, args: PythonValues) -> None:
self.strings.append("will " + self.key)

def did_interpret_op(self, op: Operation, results: PythonValues) -> None:
self.strings.append("did " + self.key)

@dataclass
@register_impls
class TestFunctions(InterpreterFunctions):
@impl(test.TestOp)
def run_test(
self, interpreter: Interpreter, op: test.TestOp, args: PythonValues
) -> PythonValues:
return ()

strings: list[str] = []
da = DemoListener(strings, "A")
db = DemoListener(strings, "B")
interpreter = Interpreter(ModuleOp([]), listeners=(da, db))
interpreter.register_implementations(TestFunctions())
interpreter.run_op(test.TestOp())

assert strings == ["will A", "will B", "did A", "did B"]
8 changes: 5 additions & 3 deletions xdsl/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def did_interpret_op(self, op: Operation, results: PythonValues) -> None: ...
"""
Runtime data associated with an interpreter functions implementation.
"""
listener: Listener = field(default=Listener())
listeners: tuple[Listener, ...] = field(default=())

@property
def symbol_table(self) -> dict[str, Operation]:
Expand Down Expand Up @@ -622,15 +622,17 @@ def _run_op(self, op: Operation, inputs: PythonValues) -> OpImplResult:
raise InterpretationError(
f"Number of operands ({operands_count}) doesn't match the number of inputs ({inputs_count})."
)
self.listener.will_interpret_op(op, inputs)
for listener in self.listeners:
listener.will_interpret_op(op, inputs)
result = self._impls.run(self, op, inputs)
if (results_count := len(op.results)) != (
actual_result_count := len(result.values)
):
raise InterpretationError(
f"Number of operation results ({results_count}) doesn't match the number of implementation results ({actual_result_count})."
)
self.listener.did_interpret_op(op, result.values)
for listener in self.listeners:
listener.did_interpret_op(op, result.values)
return result

def run_op(self, op: Operation | str, inputs: PythonValues = ()) -> PythonValues:
Expand Down

0 comments on commit 2ce4059

Please sign in to comment.