diff --git a/docs/marimo/linalg_snitch.py b/docs/marimo/linalg_snitch.py index 3a0c05d6fd..3b49315f23 100644 --- a/docs/marimo/linalg_snitch.py +++ b/docs/marimo/linalg_snitch.py @@ -499,7 +499,7 @@ def _(TypedPtr, a_shape, b_shape, c_shape, ctx, mo, riscv_module): riscv_c_shaped = ShapedArray(TypedPtr.new_float64([0.0] * c_len), c_shape) riscv_op_counter = OpCounter() - riscv_interpreter = Interpreter(riscv_module, listener=riscv_op_counter) + riscv_interpreter = Interpreter(riscv_module, listeners=(riscv_op_counter,)) register_implementations(riscv_interpreter, ctx, include_wgpu=False, include_onnx=False) @@ -549,7 +549,7 @@ def _( ): snitch_op_counter = OpCounter() snitch_interpreter = Interpreter( - snitch_stream_module, listener=snitch_op_counter + snitch_stream_module, listeners=(snitch_op_counter,) ) snitch_c_shaped = ShapedArray(TypedPtr.new_float64([0.0] * c_len), c_shape) diff --git a/tests/interpreters/test_scf_interpreter.py b/tests/interpreters/test_scf_interpreter.py index 442ea8b376..304ce50cef 100644 --- a/tests/interpreters/test_scf_interpreter.py +++ b/tests/interpreters/test_scf_interpreter.py @@ -81,7 +81,7 @@ def false_region(): def test_tracer(): tracer = OpCounter() - interpreter = Interpreter(sum_to_for_op.clone(), listener=tracer) + interpreter = Interpreter(sum_to_for_op.clone(), listeners=(tracer,)) interpreter.register_implementations(ScfFunctions()) interpreter.register_implementations(FuncFunctions()) interpreter.register_implementations(ArithFunctions()) diff --git a/tests/test_interpreter.py b/tests/test_interpreter.py index f4295853d3..6875d6dc23 100644 --- a/tests/test_interpreter.py +++ b/tests/test_interpreter.py @@ -243,3 +243,34 @@ def index_value( assert i.value_for_attribute(IntegerAttr(1, i32), i32) == 1 assert i.value_for_attribute(IntegerAttr(1, i32), index) == 1 + + +def test_combined_listener(): + @dataclass + class DemoListener(Interpreter.Listener): + strings: list[str] + key: str + + def will_interpret_op(self, op: Operation, args: PythonValues) -> None: + self.strings.append("will " + self.key) + + def did_interpret_op(self, op: Operation, results: PythonValues) -> None: + self.strings.append("did " + self.key) + + @dataclass + @register_impls + class TestFunctions(InterpreterFunctions): + @impl(test.TestOp) + def run_test( + self, interpreter: Interpreter, op: test.TestOp, args: PythonValues + ) -> PythonValues: + return () + + strings: list[str] = [] + da = DemoListener(strings, "A") + db = DemoListener(strings, "B") + interpreter = Interpreter(ModuleOp([]), listeners=(da, db)) + interpreter.register_implementations(TestFunctions()) + interpreter.run_op(test.TestOp()) + + assert strings == ["will A", "will B", "did A", "did B"] diff --git a/xdsl/interpreter.py b/xdsl/interpreter.py index e192b8eed2..859157a5ff 100644 --- a/xdsl/interpreter.py +++ b/xdsl/interpreter.py @@ -560,7 +560,7 @@ def did_interpret_op(self, op: Operation, results: PythonValues) -> None: ... """ Runtime data associated with an interpreter functions implementation. """ - listener: Listener = field(default=Listener()) + listeners: tuple[Listener, ...] = field(default=()) @property def symbol_table(self) -> dict[str, Operation]: @@ -622,7 +622,8 @@ def _run_op(self, op: Operation, inputs: PythonValues) -> OpImplResult: raise InterpretationError( f"Number of operands ({operands_count}) doesn't match the number of inputs ({inputs_count})." ) - self.listener.will_interpret_op(op, inputs) + for listener in self.listeners: + listener.will_interpret_op(op, inputs) result = self._impls.run(self, op, inputs) if (results_count := len(op.results)) != ( actual_result_count := len(result.values) @@ -630,7 +631,8 @@ def _run_op(self, op: Operation, inputs: PythonValues) -> OpImplResult: raise InterpretationError( f"Number of operation results ({results_count}) doesn't match the number of implementation results ({actual_result_count})." ) - self.listener.did_interpret_op(op, result.values) + for listener in self.listeners: + listener.did_interpret_op(op, result.values) return result def run_op(self, op: Operation | str, inputs: PythonValues = ()) -> PythonValues: