From 1d87ea44ecf3b47a79a3e93378f37b1482d0cc4c Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Thu, 19 Dec 2024 19:13:50 +0100 Subject: [PATCH 1/9] basic output attributes --- tests/dialects/test_func.py | 22 ++++++++++++++++++++ uv.lock | 2 +- xdsl/dialects/func.py | 2 ++ xdsl/dialects/utils/format.py | 39 ++++++++++++++++++++++++++++------- 4 files changed, 57 insertions(+), 8 deletions(-) diff --git a/tests/dialects/test_func.py b/tests/dialects/test_func.py index 3679e2ee80..3f4b10904b 100644 --- a/tests/dialects/test_func.py +++ b/tests/dialects/test_func.py @@ -2,10 +2,12 @@ from conftest import assert_print_op from xdsl.builder import Builder, ImplicitBuilder +from xdsl.context import MLContext from xdsl.dialects.arith import AddiOp, ConstantOp from xdsl.dialects.builtin import IntegerAttr, IntegerType, ModuleOp, i32, i64 from xdsl.dialects.func import CallOp, FuncOp, ReturnOp from xdsl.ir import Block, Region +from xdsl.parser import Parser from xdsl.traits import CallableOpInterface from xdsl.utils.exceptions import VerifyException @@ -279,3 +281,23 @@ def test_external_func_def(): assert len(ext.regions) == 1 assert len(ext.regions[0].blocks) == 0 assert ext.sym_name.data == "testname" + + +def test_output_attribute_parsing(): + ctx = MLContext() + ctx.load_op(FuncOp) + parser = Parser( + ctx, + "func.func @test(%arg0: f32 {a = 0 : i32}) -> (f32 {a = 0 : i32}, f32 {a = 0 : i32}) {}", + ) + for func_str in [ + "func.func @test() -> (f32 {a = 0 : i32}, f32 {a = 0 : i32}) {}", + "func.func @test() -> f32 {a = 0 : i32} {}", + ]: + parser = Parser(ctx, func_str) + op = parser.parse_op() + print(op) + assert isinstance(op, FuncOp) + assert op.res_attrs is not None + for attr in op.res_attrs: + assert str(attr.data["a"]) == "0 : i32" diff --git a/uv.lock b/uv.lock index 6262d8ce0e..f1f42704a3 100644 --- a/uv.lock +++ b/uv.lock @@ -2390,7 +2390,7 @@ wheels = [ [[package]] name = "xdsl" -version = "0+dynamic" +version = "0+untagged.3077.gb394263.dirty" source = { editable = "." } dependencies = [ { name = "immutabledict" }, diff --git a/xdsl/dialects/func.py b/xdsl/dialects/func.py index c5a2e28c58..1c5b097b05 100644 --- a/xdsl/dialects/func.py +++ b/xdsl/dialects/func.py @@ -129,6 +129,7 @@ def parse(cls, parser: Parser) -> FuncOp: region, extra_attrs, arg_attrs, + res_attrs, ) = parse_func_op_like( parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility") ) @@ -138,6 +139,7 @@ def parse(cls, parser: Parser) -> FuncOp: region=region, visibility=visibility, arg_attrs=arg_attrs, + res_attrs=res_attrs, ) if extra_attrs is not None: func.attributes |= extra_attrs.data diff --git a/xdsl/dialects/utils/format.py b/xdsl/dialects/utils/format.py index 8c4d7074dd..8960edbc17 100644 --- a/xdsl/dialects/utils/format.py +++ b/xdsl/dialects/utils/format.py @@ -85,9 +85,10 @@ def parse_func_op_like( Region, DictionaryAttr | None, ArrayAttr[DictionaryAttr] | None, + ArrayAttr[DictionaryAttr] | None, ]: """ - Returns the function name, argument types, return types, body, extra args, and arg_attrs. + Returns the function name, argument types, return types, body, extra args, arg_attrs and res_attrs. """ # Parse function name name = parser.parse_symbol_name().data @@ -103,6 +104,13 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute] ret = (arg, arg_attr_dict) return ret + def parse_fun_output() -> tuple[Attribute, dict[str, Attribute]]: + arg_type = parser.parse_optional_type() + if arg_type is None: + parser.raise_error("Return type should be specified") + arg_attr_dict = parser.parse_optional_dictionary_attr_dict() + return (arg_type, arg_attr_dict) + # Parse function arguments args = parser.parse_comma_separated_list( parser.Delimiter.PAREN, @@ -135,14 +143,23 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute] arg_attrs = None # Parse return type + return_types: list[Attribute] = [] + res_attrs_raw: list[dict[str, Attribute]] = [] if parser.parse_optional_punctuation("->"): - return_types = parser.parse_optional_comma_separated_list( - parser.Delimiter.PAREN, parser.parse_type + return_attributes = parser.parse_optional_comma_separated_list( + parser.Delimiter.PAREN, parse_fun_output ) - if return_types is None: - return_types = [parser.parse_type()] + if return_attributes is None: + return_attributes = [ + (parser.parse_type(), parser.parse_optional_dictionary_attr_dict()) + ] + + return_types, res_attrs_raw = zip(*return_attributes) + + if any(res_attrs_raw): + res_attrs = ArrayAttr(DictionaryAttr(attrs) for attrs in res_attrs_raw) else: - return_types = [] + res_attrs = None extra_attributes = parser.parse_optional_attr_dict_with_keyword(reserved_attr_names) @@ -151,7 +168,15 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute] if region is None: region = Region() - return name, input_types, return_types, region, extra_attributes, arg_attrs + return ( + name, + input_types, + return_types, + region, + extra_attributes, + arg_attrs, + res_attrs, + ) def print_func_argument( From be9353eb8ffaeb31a148d25a2f96e9df5c9f0c48 Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Fri, 20 Dec 2024 11:18:03 +0100 Subject: [PATCH 2/9] temp --- tests/dialects/test_func.py | 6 ++++-- tests/filecheck/dialects/func/func_ops.mlir | 12 ++++++++---- xdsl/dialects/func.py | 1 + xdsl/dialects/utils/format.py | 19 ++++++++++++++++++- 4 files changed, 31 insertions(+), 7 deletions(-) diff --git a/tests/dialects/test_func.py b/tests/dialects/test_func.py index 3f4b10904b..165075d106 100644 --- a/tests/dialects/test_func.py +++ b/tests/dialects/test_func.py @@ -6,6 +6,7 @@ from xdsl.dialects.arith import AddiOp, ConstantOp from xdsl.dialects.builtin import IntegerAttr, IntegerType, ModuleOp, i32, i64 from xdsl.dialects.func import CallOp, FuncOp, ReturnOp +from xdsl.dialects.test import TestOp from xdsl.ir import Block, Region from xdsl.parser import Parser from xdsl.traits import CallableOpInterface @@ -286,13 +287,14 @@ def test_external_func_def(): def test_output_attribute_parsing(): ctx = MLContext() ctx.load_op(FuncOp) + ctx.load_op(TestOp) parser = Parser( ctx, "func.func @test(%arg0: f32 {a = 0 : i32}) -> (f32 {a = 0 : i32}, f32 {a = 0 : i32}) {}", ) for func_str in [ - "func.func @test() -> (f32 {a = 0 : i32}, f32 {a = 0 : i32}) {}", - "func.func @test() -> f32 {a = 0 : i32} {}", + 'func.func @test() -> (f32 {a = 0 : i32}, f32 {a = 0 : i32}) {"test.op"() : () -> ()}', + 'func.func @test(%a : f32 {a = 0 : i32}) -> f32 {a = 0 : i32} {"test.op"() : () -> ()}', ]: parser = Parser(ctx, func_str) op = parser.parse_op() diff --git a/tests/filecheck/dialects/func/func_ops.mlir b/tests/filecheck/dialects/func/func_ops.mlir index 608d4f2afb..808f9371e8 100644 --- a/tests/filecheck/dialects/func/func_ops.mlir +++ b/tests/filecheck/dialects/func/func_ops.mlir @@ -30,10 +30,10 @@ builtin.module { // CHECK-NEXT: func.return // CHECK-NEXT: } - func.func @arg_rec(%0 : !test.type<"int">) -> !test.type<"int"> { - %1 = func.call @arg_rec(%0) : (!test.type<"int">) -> !test.type<"int"> - func.return %1 : !test.type<"int"> - } + // problem func.func @arg_rec(%0 : !test.type<"int">) -> !test.type<"int"> { + // %1 = func.call @arg_rec(%0) : (!test.type<"int">) -> !test.type<"int"> + // func.return %1 : !test.type<"int"> + //} // CHECK: func.func @arg_rec(%0 : !test.type<"int">) -> !test.type<"int"> { // CHECK-NEXT: %{{.*}} = func.call @arg_rec(%{{.*}}) : (!test.type<"int">) -> !test.type<"int"> @@ -71,4 +71,8 @@ builtin.module { // CHECK: func.func public @arg_attrs(%{{.*}} : tensor<8x8xf64> {"llvm.noalias"}, %{{.*}} : tensor<8x8xf64> {"llvm.noalias"}, %{{.*}} : tensor<8x8xf64> {"llvm.noalias"}) -> tensor<8x8xf64> { // CHECK-NEXT: return %{{.*}} : tensor<8x8xf64> // CHECK-NEXT: } + + func.func @output_attributes() -> (f32 {a = 0 : i32}, f32 {b = 0 : i32, c = 1 : f64}) { + func.return + } } diff --git a/xdsl/dialects/func.py b/xdsl/dialects/func.py index 1c5b097b05..a60ba66f3d 100644 --- a/xdsl/dialects/func.py +++ b/xdsl/dialects/func.py @@ -157,6 +157,7 @@ def print(self, printer: Printer): self.body, self.attributes, arg_attrs=self.arg_attrs, + res_attrs=self.res_attrs, reserved_attr_names=( "sym_name", "function_type", diff --git a/xdsl/dialects/utils/format.py b/xdsl/dialects/utils/format.py index 8960edbc17..08642a1762 100644 --- a/xdsl/dialects/utils/format.py +++ b/xdsl/dialects/utils/format.py @@ -43,6 +43,7 @@ def print_func_op_like( attributes: dict[str, Attribute], *, arg_attrs: ArrayAttr[DictionaryAttr] | None = None, + res_attrs: ArrayAttr[DictionaryAttr] | None = None, reserved_attr_names: Sequence[str], ): printer.print(f" @{sym_name.data}") @@ -62,7 +63,15 @@ def print_func_op_like( printer.print("-> ") if len(function_type.outputs) > 1: printer.print("(") - printer.print_list(function_type.outputs, printer.print_attribute) + if res_attrs is not None: + printer.print_list( + zip(function_type.outputs, res_attrs), + lambda arg_with_attrs: print_func_output( + printer, arg_with_attrs[0], arg_with_attrs[1] + ), + ) + else: + printer.print_list(function_type.outputs, printer.print_attribute) if len(function_type.outputs) > 1: printer.print(")") printer.print(" ") @@ -187,6 +196,14 @@ def print_func_argument( printer.print_op_attributes(attrs.data) +def print_func_output( + printer: Printer, out_type: Attribute, attrs: DictionaryAttr | None +): + printer.print_attribute(out_type) + if attrs is not None and attrs.data: + printer.print_op_attributes(attrs.data) + + def print_assignment(printer: Printer, arg: BlockArgument, val: SSAValue): printer.print_block_argument(arg, print_type=False) printer.print_string(" = ") From 2027c94be69db8b9290a24f5c36b663f0b798798 Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Fri, 20 Dec 2024 11:52:41 +0100 Subject: [PATCH 3/9] roundtrip test --- tests/dialects/test_func.py | 24 --------------------- tests/filecheck/dialects/func/func_ops.mlir | 18 ++++++++++------ xdsl/dialects/utils/format.py | 5 ++--- 3 files changed, 14 insertions(+), 33 deletions(-) diff --git a/tests/dialects/test_func.py b/tests/dialects/test_func.py index 165075d106..3679e2ee80 100644 --- a/tests/dialects/test_func.py +++ b/tests/dialects/test_func.py @@ -2,13 +2,10 @@ from conftest import assert_print_op from xdsl.builder import Builder, ImplicitBuilder -from xdsl.context import MLContext from xdsl.dialects.arith import AddiOp, ConstantOp from xdsl.dialects.builtin import IntegerAttr, IntegerType, ModuleOp, i32, i64 from xdsl.dialects.func import CallOp, FuncOp, ReturnOp -from xdsl.dialects.test import TestOp from xdsl.ir import Block, Region -from xdsl.parser import Parser from xdsl.traits import CallableOpInterface from xdsl.utils.exceptions import VerifyException @@ -282,24 +279,3 @@ def test_external_func_def(): assert len(ext.regions) == 1 assert len(ext.regions[0].blocks) == 0 assert ext.sym_name.data == "testname" - - -def test_output_attribute_parsing(): - ctx = MLContext() - ctx.load_op(FuncOp) - ctx.load_op(TestOp) - parser = Parser( - ctx, - "func.func @test(%arg0: f32 {a = 0 : i32}) -> (f32 {a = 0 : i32}, f32 {a = 0 : i32}) {}", - ) - for func_str in [ - 'func.func @test() -> (f32 {a = 0 : i32}, f32 {a = 0 : i32}) {"test.op"() : () -> ()}', - 'func.func @test(%a : f32 {a = 0 : i32}) -> f32 {a = 0 : i32} {"test.op"() : () -> ()}', - ]: - parser = Parser(ctx, func_str) - op = parser.parse_op() - print(op) - assert isinstance(op, FuncOp) - assert op.res_attrs is not None - for attr in op.res_attrs: - assert str(attr.data["a"]) == "0 : i32" diff --git a/tests/filecheck/dialects/func/func_ops.mlir b/tests/filecheck/dialects/func/func_ops.mlir index 808f9371e8..44a3a05a33 100644 --- a/tests/filecheck/dialects/func/func_ops.mlir +++ b/tests/filecheck/dialects/func/func_ops.mlir @@ -30,10 +30,10 @@ builtin.module { // CHECK-NEXT: func.return // CHECK-NEXT: } - // problem func.func @arg_rec(%0 : !test.type<"int">) -> !test.type<"int"> { - // %1 = func.call @arg_rec(%0) : (!test.type<"int">) -> !test.type<"int"> - // func.return %1 : !test.type<"int"> - //} + func.func @arg_rec(%0 : !test.type<"int">) -> !test.type<"int"> { + %1 = func.call @arg_rec(%0) : (!test.type<"int">) -> !test.type<"int"> + func.return %1 : !test.type<"int"> + } // CHECK: func.func @arg_rec(%0 : !test.type<"int">) -> !test.type<"int"> { // CHECK-NEXT: %{{.*}} = func.call @arg_rec(%{{.*}}) : (!test.type<"int">) -> !test.type<"int"> @@ -72,7 +72,13 @@ builtin.module { // CHECK-NEXT: return %{{.*}} : tensor<8x8xf64> // CHECK-NEXT: } - func.func @output_attributes() -> (f32 {a = 0 : i32}, f32 {b = 0 : i32, c = 1 : f64}) { - func.return + func.func @output_attributes() -> (f32 {a = 0 : i32}, f32 {b = 0 : i32, c = 1 : i64}) { + %r1, %r2 = "test.op"() : () -> (f32, f32) + return %r1, %r2 : f32, f32 } + + // CHECK: func.func @output_attributes() -> (f32 {"a" = 0 : i32}, f32 {"b" = 0 : i32, "c" = 1 : i64}) { + // CHECK-NEXT: %r1, %r2 = "test.op"() : () -> (f32, f32) + // CHECK-NEXT: func.return %r1, %r2 : f32, f32 + // CHECK-NEXT: } } diff --git a/xdsl/dialects/utils/format.py b/xdsl/dialects/utils/format.py index 08642a1762..f360eb6a77 100644 --- a/xdsl/dialects/utils/format.py +++ b/xdsl/dialects/utils/format.py @@ -159,9 +159,8 @@ def parse_fun_output() -> tuple[Attribute, dict[str, Attribute]]: parser.Delimiter.PAREN, parse_fun_output ) if return_attributes is None: - return_attributes = [ - (parser.parse_type(), parser.parse_optional_dictionary_attr_dict()) - ] + # output attributes are supported only if return results are enclosed in brackets (...) + return_attributes = [(parser.parse_type(), None)] return_types, res_attrs_raw = zip(*return_attributes) From 79f7194d0c8daf4cf95f6e2d446660f546ea83a5 Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Fri, 20 Dec 2024 11:54:27 +0100 Subject: [PATCH 4/9] roundtrip test --- xdsl/dialects/arm_func.py | 14 +++++--------- xdsl/dialects/csl/csl.py | 14 +++++--------- xdsl/dialects/riscv_func.py | 14 +++++--------- 3 files changed, 15 insertions(+), 27 deletions(-) diff --git a/xdsl/dialects/arm_func.py b/xdsl/dialects/arm_func.py index 7b4feb24bc..feaafb931f 100644 --- a/xdsl/dialects/arm_func.py +++ b/xdsl/dialects/arm_func.py @@ -83,15 +83,11 @@ def __init__( @classmethod def parse(cls, parser: Parser) -> FuncOp: visibility = parser.parse_optional_visibility_keyword() - ( - name, - input_types, - return_types, - region, - extra_attrs, - arg_attrs, - ) = parse_func_op_like( - parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility") + (name, input_types, return_types, region, extra_attrs, arg_attrs, _) = ( + parse_func_op_like( + parser, + reserved_attr_names=("sym_name", "function_type", "sym_visibility"), + ) ) if arg_attrs: raise NotImplementedError("arg_attrs not implemented in riscv_func") diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index e8f6da22e0..c199b827ff 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -889,15 +889,11 @@ def verify_(self) -> None: @classmethod def parse(cls, parser: Parser) -> TaskOp: pos = parser.pos - ( - name, - input_types, - return_types, - region, - extra_attrs, - arg_attrs, - ) = parse_func_op_like( - parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility") + (name, input_types, return_types, region, extra_attrs, arg_attrs, _) = ( + parse_func_op_like( + parser, + reserved_attr_names=("sym_name", "function_type", "sym_visibility"), + ) ) if ( extra_attrs is None diff --git a/xdsl/dialects/riscv_func.py b/xdsl/dialects/riscv_func.py index b45ef90794..2e6a7d7138 100644 --- a/xdsl/dialects/riscv_func.py +++ b/xdsl/dialects/riscv_func.py @@ -174,15 +174,11 @@ def __init__( @classmethod def parse(cls, parser: Parser) -> FuncOp: visibility = parser.parse_optional_visibility_keyword() - ( - name, - input_types, - return_types, - region, - extra_attrs, - arg_attrs, - ) = parse_func_op_like( - parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility") + (name, input_types, return_types, region, extra_attrs, arg_attrs, _) = ( + parse_func_op_like( + parser, + reserved_attr_names=("sym_name", "function_type", "sym_visibility"), + ) ) if arg_attrs: raise NotImplementedError("arg_attrs not implemented in riscv_func") From a3bcb870d9d7f6d13b3f7adaff6d823c5032b3c1 Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Fri, 20 Dec 2024 12:11:49 +0100 Subject: [PATCH 5/9] lock file change --- uv.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uv.lock b/uv.lock index f1f42704a3..6262d8ce0e 100644 --- a/uv.lock +++ b/uv.lock @@ -2390,7 +2390,7 @@ wheels = [ [[package]] name = "xdsl" -version = "0+untagged.3077.gb394263.dirty" +version = "0+dynamic" source = { editable = "." } dependencies = [ { name = "immutabledict" }, From c5deb619d45a3f6e96b5440d1ff2a39776734d7d Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Fri, 20 Dec 2024 12:14:40 +0100 Subject: [PATCH 6/9] style fix --- xdsl/dialects/csl/csl.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index c199b827ff..19279d9fe3 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -779,15 +779,11 @@ def verify_(self) -> None: @classmethod def parse(cls, parser: Parser) -> FuncOp: - ( - name, - input_types, - return_types, - region, - extra_attrs, - arg_attrs, - ) = parse_func_op_like( - parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility") + (name, input_types, return_types, region, extra_attrs, arg_attrs, _) = ( + parse_func_op_like( + parser, + reserved_attr_names=("sym_name", "function_type", "sym_visibility"), + ) ) assert ( From 5c401b0da0dc010ff05d1106c765933d9c619a2c Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Fri, 20 Dec 2024 13:00:00 +0100 Subject: [PATCH 7/9] fix fail --- xdsl/dialects/utils/format.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/xdsl/dialects/utils/format.py b/xdsl/dialects/utils/format.py index f360eb6a77..3b0a5c91c2 100644 --- a/xdsl/dialects/utils/format.py +++ b/xdsl/dialects/utils/format.py @@ -153,18 +153,21 @@ def parse_fun_output() -> tuple[Attribute, dict[str, Attribute]]: # Parse return type return_types: list[Attribute] = [] - res_attrs_raw: list[dict[str, Attribute]] = [] + res_attrs_raw: list[dict[str, Attribute]] | None = [] if parser.parse_optional_punctuation("->"): return_attributes = parser.parse_optional_comma_separated_list( parser.Delimiter.PAREN, parse_fun_output ) if return_attributes is None: # output attributes are supported only if return results are enclosed in brackets (...) - return_attributes = [(parser.parse_type(), None)] - - return_types, res_attrs_raw = zip(*return_attributes) + return_types, res_attrs_raw = [parser.parse_type()], None + else: + return_types, res_attrs_raw = ( + [el[0] for el in return_attributes], + [el[1] for el in return_attributes], + ) - if any(res_attrs_raw): + if res_attrs_raw is not None and any(res_attrs_raw): res_attrs = ArrayAttr(DictionaryAttr(attrs) for attrs in res_attrs_raw) else: res_attrs = None From 188f0cbbbd2d4d5f688cc5a494a236eb7fb779d5 Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Fri, 20 Dec 2024 21:10:04 +0100 Subject: [PATCH 8/9] more tests --- tests/filecheck/dialects/func/func_ops.mlir | 4 ++-- tests/filecheck/dialects/func/func_ops_generic.mlir | 10 ++++++++++ .../with-mlir/dialects/func/func_ops.mlir | 10 ++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/tests/filecheck/dialects/func/func_ops.mlir b/tests/filecheck/dialects/func/func_ops.mlir index 44a3a05a33..54563a3b18 100644 --- a/tests/filecheck/dialects/func/func_ops.mlir +++ b/tests/filecheck/dialects/func/func_ops.mlir @@ -72,12 +72,12 @@ builtin.module { // CHECK-NEXT: return %{{.*}} : tensor<8x8xf64> // CHECK-NEXT: } - func.func @output_attributes() -> (f32 {a = 0 : i32}, f32 {b = 0 : i32, c = 1 : i64}) { + func.func @output_attributes() -> (f32 {dialect.a = 0 : i32}, f32 {dialect.b = 0 : i32, dialect.c = 1 : i64}) { %r1, %r2 = "test.op"() : () -> (f32, f32) return %r1, %r2 : f32, f32 } - // CHECK: func.func @output_attributes() -> (f32 {"a" = 0 : i32}, f32 {"b" = 0 : i32, "c" = 1 : i64}) { + // CHECK: func.func @output_attributes() -> (f32 {"dialect.a" = 0 : i32}, f32 {"dialect.b" = 0 : i32, "dialect.c" = 1 : i64}) { // CHECK-NEXT: %r1, %r2 = "test.op"() : () -> (f32, f32) // CHECK-NEXT: func.return %r1, %r2 : f32, f32 // CHECK-NEXT: } diff --git a/tests/filecheck/dialects/func/func_ops_generic.mlir b/tests/filecheck/dialects/func/func_ops_generic.mlir index d35b82177c..471a0b117a 100644 --- a/tests/filecheck/dialects/func/func_ops_generic.mlir +++ b/tests/filecheck/dialects/func/func_ops_generic.mlir @@ -10,3 +10,13 @@ // CHECK-NEXT: ^0(%arg0 : tensor<8x8xf64>, %arg1 : tensor<8x8xf64>): // CHECK-NEXT: "func.return"(%arg0, %arg1) : (tensor<8x8xf64>, tensor<8x8xf64>) -> () // CHECK-NEXT: }) : () -> () + +func.func @output_attributes() -> (f32 {dialect.a = 0 : i32}, f32 {dialect.b = 0 : i32, dialect.c = 1 : i64}) { + %r1, %r2 = "test.op"() : () -> (f32, f32) + return %r1, %r2 : f32, f32 +} + +// CHECK: "func.func"() <{"sym_name" = "output_attributes", "function_type" = () -> (f32, f32), "res_attrs" = [{"dialect.a" = 0 : i32}, {"dialect.b" = 0 : i32, "dialect.c" = 1 : i64}]}> ({ +// CHECK-NEXT: %r1, %r2 = "test.op"() : () -> (f32, f32) +// CHECK-NEXT: "func.return"(%r1, %r2) : (f32, f32) -> () +// CHECK-NEXT: }) : () -> () diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/func/func_ops.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/func/func_ops.mlir index cdfda54a88..c916c0bc72 100644 --- a/tests/filecheck/mlir-conversion/with-mlir/dialects/func/func_ops.mlir +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/func/func_ops.mlir @@ -72,4 +72,14 @@ builtin.module { // CHECK: func.func public @arg_attrs(%{{.*}}: tensor<8x8xf64> {"llvm.noalias"}, %{{.*}}: tensor<8x8xf64> {"llvm.noalias"}, %{{.*}}: tensor<8x8xf64> {"llvm.noalias"}) -> tensor<8x8xf64> { // CHECK-NEXT: func.return %{{.*}} : tensor<8x8xf64> // CHECK-NEXT: } + + func.func @output_attributes() -> (f32 {dialect.a = 0 : i32}, f32 {dialect.b = 0 : i32, dialect.c = 1 : i64}) { + %r1, %r2 = "test.op"() : () -> (f32, f32) + return %r1, %r2 : f32, f32 + } + + // CHECK: func.func @output_attributes() -> (f32 {"dialect.a" = 0 : i32}, f32 {"dialect.b" = 0 : i32, "dialect.c" = 1 : i64}) { + // CHECK-NEXT: %0, %1 = "test.op"() : () -> (f32, f32) + // CHECK-NEXT: func.return %0, %1 : f32, f32 + // CHECK-NEXT: } } From 8a83c76931c9a45bc83e5fb873014db0b5910e06 Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Fri, 20 Dec 2024 21:15:29 +0100 Subject: [PATCH 9/9] not implemented error --- xdsl/dialects/arm_func.py | 6 ++++-- xdsl/dialects/csl/csl.py | 9 +++++++-- xdsl/dialects/riscv_func.py | 4 +++- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/xdsl/dialects/arm_func.py b/xdsl/dialects/arm_func.py index feaafb931f..e05f28b106 100644 --- a/xdsl/dialects/arm_func.py +++ b/xdsl/dialects/arm_func.py @@ -83,14 +83,16 @@ def __init__( @classmethod def parse(cls, parser: Parser) -> FuncOp: visibility = parser.parse_optional_visibility_keyword() - (name, input_types, return_types, region, extra_attrs, arg_attrs, _) = ( + (name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = ( parse_func_op_like( parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility"), ) ) if arg_attrs: - raise NotImplementedError("arg_attrs not implemented in riscv_func") + raise NotImplementedError("arg_attrs not implemented in arm_func") + if res_attrs: + raise NotImplementedError("res_attrs not implemented in arm_func") func = FuncOp(name, region, (input_types, return_types), visibility) if extra_attrs is not None: func.attributes |= extra_attrs.data diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index 19279d9fe3..622396ab51 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -779,13 +779,16 @@ def verify_(self) -> None: @classmethod def parse(cls, parser: Parser) -> FuncOp: - (name, input_types, return_types, region, extra_attrs, arg_attrs, _) = ( + (name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = ( parse_func_op_like( parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility"), ) ) + if res_attrs: + raise NotImplementedError("res_attrs not implemented in csl FuncOp") + assert ( len(return_types) <= 1 ), f"{cls.name} can't have more than one result type!" @@ -885,12 +888,14 @@ def verify_(self) -> None: @classmethod def parse(cls, parser: Parser) -> TaskOp: pos = parser.pos - (name, input_types, return_types, region, extra_attrs, arg_attrs, _) = ( + (name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = ( parse_func_op_like( parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility"), ) ) + if res_attrs: + raise NotImplementedError("res_attrs not implemented in csl TaskOp") if ( extra_attrs is None or "kind" not in extra_attrs.data diff --git a/xdsl/dialects/riscv_func.py b/xdsl/dialects/riscv_func.py index 2e6a7d7138..854fa13e9d 100644 --- a/xdsl/dialects/riscv_func.py +++ b/xdsl/dialects/riscv_func.py @@ -174,7 +174,7 @@ def __init__( @classmethod def parse(cls, parser: Parser) -> FuncOp: visibility = parser.parse_optional_visibility_keyword() - (name, input_types, return_types, region, extra_attrs, arg_attrs, _) = ( + (name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = ( parse_func_op_like( parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility"), @@ -182,6 +182,8 @@ def parse(cls, parser: Parser) -> FuncOp: ) if arg_attrs: raise NotImplementedError("arg_attrs not implemented in riscv_func") + if res_attrs: + raise NotImplementedError("res_attrs not implemented in riscv_func") func = FuncOp(name, region, (input_types, return_types), visibility) if extra_attrs is not None: func.attributes |= extra_attrs.data