From 955600de543f9ec8f3ccb7e43865a85eef3ec78c Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Sat, 21 Dec 2024 12:36:31 +0100 Subject: [PATCH] core: Add attributes to func outputs (#3661) Mlir [supports](https://mlir.llvm.org/docs/Dialects/Func/#funcfunc-funcfuncop) attributes for output types. We need for cross compatibility with other mlir projects. --- tests/filecheck/dialects/func/func_ops.mlir | 10 ++++ .../dialects/func/func_ops_generic.mlir | 10 ++++ .../with-mlir/dialects/func/func_ops.mlir | 10 ++++ xdsl/dialects/arm_func.py | 18 +++--- xdsl/dialects/csl/csl.py | 33 +++++----- xdsl/dialects/func.py | 3 + xdsl/dialects/riscv_func.py | 16 +++-- xdsl/dialects/utils/format.py | 60 ++++++++++++++++--- 8 files changed, 115 insertions(+), 45 deletions(-) diff --git a/tests/filecheck/dialects/func/func_ops.mlir b/tests/filecheck/dialects/func/func_ops.mlir index 608d4f2afb..54563a3b18 100644 --- a/tests/filecheck/dialects/func/func_ops.mlir +++ b/tests/filecheck/dialects/func/func_ops.mlir @@ -71,4 +71,14 @@ builtin.module { // CHECK: func.func public @arg_attrs(%{{.*}} : tensor<8x8xf64> {"llvm.noalias"}, %{{.*}} : tensor<8x8xf64> {"llvm.noalias"}, %{{.*}} : tensor<8x8xf64> {"llvm.noalias"}) -> tensor<8x8xf64> { // CHECK-NEXT: return %{{.*}} : tensor<8x8xf64> // CHECK-NEXT: } + + func.func @output_attributes() -> (f32 {dialect.a = 0 : i32}, f32 {dialect.b = 0 : i32, dialect.c = 1 : i64}) { + %r1, %r2 = "test.op"() : () -> (f32, f32) + return %r1, %r2 : f32, f32 + } + + // CHECK: func.func @output_attributes() -> (f32 {"dialect.a" = 0 : i32}, f32 {"dialect.b" = 0 : i32, "dialect.c" = 1 : i64}) { + // CHECK-NEXT: %r1, %r2 = "test.op"() : () -> (f32, f32) + // CHECK-NEXT: func.return %r1, %r2 : f32, f32 + // CHECK-NEXT: } } diff --git a/tests/filecheck/dialects/func/func_ops_generic.mlir b/tests/filecheck/dialects/func/func_ops_generic.mlir index d35b82177c..471a0b117a 100644 --- a/tests/filecheck/dialects/func/func_ops_generic.mlir +++ b/tests/filecheck/dialects/func/func_ops_generic.mlir @@ -10,3 +10,13 @@ // CHECK-NEXT: ^0(%arg0 : tensor<8x8xf64>, %arg1 : tensor<8x8xf64>): // CHECK-NEXT: "func.return"(%arg0, %arg1) : (tensor<8x8xf64>, tensor<8x8xf64>) -> () // CHECK-NEXT: }) : () -> () + +func.func @output_attributes() -> (f32 {dialect.a = 0 : i32}, f32 {dialect.b = 0 : i32, dialect.c = 1 : i64}) { + %r1, %r2 = "test.op"() : () -> (f32, f32) + return %r1, %r2 : f32, f32 +} + +// CHECK: "func.func"() <{"sym_name" = "output_attributes", "function_type" = () -> (f32, f32), "res_attrs" = [{"dialect.a" = 0 : i32}, {"dialect.b" = 0 : i32, "dialect.c" = 1 : i64}]}> ({ +// CHECK-NEXT: %r1, %r2 = "test.op"() : () -> (f32, f32) +// CHECK-NEXT: "func.return"(%r1, %r2) : (f32, f32) -> () +// CHECK-NEXT: }) : () -> () diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/func/func_ops.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/func/func_ops.mlir index cdfda54a88..c916c0bc72 100644 --- a/tests/filecheck/mlir-conversion/with-mlir/dialects/func/func_ops.mlir +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/func/func_ops.mlir @@ -72,4 +72,14 @@ builtin.module { // CHECK: func.func public @arg_attrs(%{{.*}}: tensor<8x8xf64> {"llvm.noalias"}, %{{.*}}: tensor<8x8xf64> {"llvm.noalias"}, %{{.*}}: tensor<8x8xf64> {"llvm.noalias"}) -> tensor<8x8xf64> { // CHECK-NEXT: func.return %{{.*}} : tensor<8x8xf64> // CHECK-NEXT: } + + func.func @output_attributes() -> (f32 {dialect.a = 0 : i32}, f32 {dialect.b = 0 : i32, dialect.c = 1 : i64}) { + %r1, %r2 = "test.op"() : () -> (f32, f32) + return %r1, %r2 : f32, f32 + } + + // CHECK: func.func @output_attributes() -> (f32 {"dialect.a" = 0 : i32}, f32 {"dialect.b" = 0 : i32, "dialect.c" = 1 : i64}) { + // CHECK-NEXT: %0, %1 = "test.op"() : () -> (f32, f32) + // CHECK-NEXT: func.return %0, %1 : f32, f32 + // CHECK-NEXT: } } diff --git a/xdsl/dialects/arm_func.py b/xdsl/dialects/arm_func.py index 7b4feb24bc..e05f28b106 100644 --- a/xdsl/dialects/arm_func.py +++ b/xdsl/dialects/arm_func.py @@ -83,18 +83,16 @@ def __init__( @classmethod def parse(cls, parser: Parser) -> FuncOp: visibility = parser.parse_optional_visibility_keyword() - ( - name, - input_types, - return_types, - region, - extra_attrs, - arg_attrs, - ) = parse_func_op_like( - parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility") + (name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = ( + parse_func_op_like( + parser, + reserved_attr_names=("sym_name", "function_type", "sym_visibility"), + ) ) if arg_attrs: - raise NotImplementedError("arg_attrs not implemented in riscv_func") + raise NotImplementedError("arg_attrs not implemented in arm_func") + if res_attrs: + raise NotImplementedError("res_attrs not implemented in arm_func") func = FuncOp(name, region, (input_types, return_types), visibility) if extra_attrs is not None: func.attributes |= extra_attrs.data diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index fe9192d47a..6a659ab356 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -780,17 +780,16 @@ def verify_(self) -> None: @classmethod def parse(cls, parser: Parser) -> FuncOp: - ( - name, - input_types, - return_types, - region, - extra_attrs, - arg_attrs, - ) = parse_func_op_like( - parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility") + (name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = ( + parse_func_op_like( + parser, + reserved_attr_names=("sym_name", "function_type", "sym_visibility"), + ) ) + if res_attrs: + raise NotImplementedError("res_attrs not implemented in csl FuncOp") + assert ( len(return_types) <= 1 ), f"{cls.name} can't have more than one result type!" @@ -890,16 +889,14 @@ def verify_(self) -> None: @classmethod def parse(cls, parser: Parser) -> TaskOp: pos = parser.pos - ( - name, - input_types, - return_types, - region, - extra_attrs, - arg_attrs, - ) = parse_func_op_like( - parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility") + (name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = ( + parse_func_op_like( + parser, + reserved_attr_names=("sym_name", "function_type", "sym_visibility"), + ) ) + if res_attrs: + raise NotImplementedError("res_attrs not implemented in csl TaskOp") if ( extra_attrs is None or "kind" not in extra_attrs.data diff --git a/xdsl/dialects/func.py b/xdsl/dialects/func.py index 279f22d3a1..bf0d62b490 100644 --- a/xdsl/dialects/func.py +++ b/xdsl/dialects/func.py @@ -166,6 +166,7 @@ def parse(cls, parser: Parser) -> FuncOp: region, extra_attrs, arg_attrs, + res_attrs, ) = parse_func_op_like( parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility") ) @@ -175,6 +176,7 @@ def parse(cls, parser: Parser) -> FuncOp: region=region, visibility=visibility, arg_attrs=arg_attrs, + res_attrs=res_attrs, ) if extra_attrs is not None: func.attributes |= extra_attrs.data @@ -192,6 +194,7 @@ def print(self, printer: Printer): self.body, self.attributes, arg_attrs=self.arg_attrs, + res_attrs=self.res_attrs, reserved_attr_names=( "sym_name", "function_type", diff --git a/xdsl/dialects/riscv_func.py b/xdsl/dialects/riscv_func.py index b45ef90794..854fa13e9d 100644 --- a/xdsl/dialects/riscv_func.py +++ b/xdsl/dialects/riscv_func.py @@ -174,18 +174,16 @@ def __init__( @classmethod def parse(cls, parser: Parser) -> FuncOp: visibility = parser.parse_optional_visibility_keyword() - ( - name, - input_types, - return_types, - region, - extra_attrs, - arg_attrs, - ) = parse_func_op_like( - parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility") + (name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = ( + parse_func_op_like( + parser, + reserved_attr_names=("sym_name", "function_type", "sym_visibility"), + ) ) if arg_attrs: raise NotImplementedError("arg_attrs not implemented in riscv_func") + if res_attrs: + raise NotImplementedError("res_attrs not implemented in riscv_func") func = FuncOp(name, region, (input_types, return_types), visibility) if extra_attrs is not None: func.attributes |= extra_attrs.data diff --git a/xdsl/dialects/utils/format.py b/xdsl/dialects/utils/format.py index 8c4d7074dd..3b0a5c91c2 100644 --- a/xdsl/dialects/utils/format.py +++ b/xdsl/dialects/utils/format.py @@ -43,6 +43,7 @@ def print_func_op_like( attributes: dict[str, Attribute], *, arg_attrs: ArrayAttr[DictionaryAttr] | None = None, + res_attrs: ArrayAttr[DictionaryAttr] | None = None, reserved_attr_names: Sequence[str], ): printer.print(f" @{sym_name.data}") @@ -62,7 +63,15 @@ def print_func_op_like( printer.print("-> ") if len(function_type.outputs) > 1: printer.print("(") - printer.print_list(function_type.outputs, printer.print_attribute) + if res_attrs is not None: + printer.print_list( + zip(function_type.outputs, res_attrs), + lambda arg_with_attrs: print_func_output( + printer, arg_with_attrs[0], arg_with_attrs[1] + ), + ) + else: + printer.print_list(function_type.outputs, printer.print_attribute) if len(function_type.outputs) > 1: printer.print(")") printer.print(" ") @@ -85,9 +94,10 @@ def parse_func_op_like( Region, DictionaryAttr | None, ArrayAttr[DictionaryAttr] | None, + ArrayAttr[DictionaryAttr] | None, ]: """ - Returns the function name, argument types, return types, body, extra args, and arg_attrs. + Returns the function name, argument types, return types, body, extra args, arg_attrs and res_attrs. """ # Parse function name name = parser.parse_symbol_name().data @@ -103,6 +113,13 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute] ret = (arg, arg_attr_dict) return ret + def parse_fun_output() -> tuple[Attribute, dict[str, Attribute]]: + arg_type = parser.parse_optional_type() + if arg_type is None: + parser.raise_error("Return type should be specified") + arg_attr_dict = parser.parse_optional_dictionary_attr_dict() + return (arg_type, arg_attr_dict) + # Parse function arguments args = parser.parse_comma_separated_list( parser.Delimiter.PAREN, @@ -135,14 +152,25 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute] arg_attrs = None # Parse return type + return_types: list[Attribute] = [] + res_attrs_raw: list[dict[str, Attribute]] | None = [] if parser.parse_optional_punctuation("->"): - return_types = parser.parse_optional_comma_separated_list( - parser.Delimiter.PAREN, parser.parse_type + return_attributes = parser.parse_optional_comma_separated_list( + parser.Delimiter.PAREN, parse_fun_output ) - if return_types is None: - return_types = [parser.parse_type()] + if return_attributes is None: + # output attributes are supported only if return results are enclosed in brackets (...) + return_types, res_attrs_raw = [parser.parse_type()], None + else: + return_types, res_attrs_raw = ( + [el[0] for el in return_attributes], + [el[1] for el in return_attributes], + ) + + if res_attrs_raw is not None and any(res_attrs_raw): + res_attrs = ArrayAttr(DictionaryAttr(attrs) for attrs in res_attrs_raw) else: - return_types = [] + res_attrs = None extra_attributes = parser.parse_optional_attr_dict_with_keyword(reserved_attr_names) @@ -151,7 +179,15 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute] if region is None: region = Region() - return name, input_types, return_types, region, extra_attributes, arg_attrs + return ( + name, + input_types, + return_types, + region, + extra_attributes, + arg_attrs, + res_attrs, + ) def print_func_argument( @@ -162,6 +198,14 @@ def print_func_argument( printer.print_op_attributes(attrs.data) +def print_func_output( + printer: Printer, out_type: Attribute, attrs: DictionaryAttr | None +): + printer.print_attribute(out_type) + if attrs is not None and attrs.data: + printer.print_op_attributes(attrs.data) + + def print_assignment(printer: Printer, arg: BlockArgument, val: SSAValue): printer.print_block_argument(arg, print_type=False) printer.print_string(" = ")