Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: Add attributes to func outputs #3661

Merged
merged 9 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions tests/filecheck/dialects/func/func_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,14 @@ builtin.module {
// CHECK: func.func public @arg_attrs(%{{.*}} : tensor<8x8xf64> {"llvm.noalias"}, %{{.*}} : tensor<8x8xf64> {"llvm.noalias"}, %{{.*}} : tensor<8x8xf64> {"llvm.noalias"}) -> tensor<8x8xf64> {
// CHECK-NEXT: return %{{.*}} : tensor<8x8xf64>
// CHECK-NEXT: }

func.func @output_attributes() -> (f32 {dialect.a = 0 : i32}, f32 {dialect.b = 0 : i32, dialect.c = 1 : i64}) {
%r1, %r2 = "test.op"() : () -> (f32, f32)
return %r1, %r2 : f32, f32
}

// CHECK: func.func @output_attributes() -> (f32 {"dialect.a" = 0 : i32}, f32 {"dialect.b" = 0 : i32, "dialect.c" = 1 : i64}) {
// CHECK-NEXT: %r1, %r2 = "test.op"() : () -> (f32, f32)
// CHECK-NEXT: func.return %r1, %r2 : f32, f32
// CHECK-NEXT: }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have a generic format filecheck and an mlir integration filecheck for this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. Done.

}
10 changes: 10 additions & 0 deletions tests/filecheck/dialects/func/func_ops_generic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,13 @@
// CHECK-NEXT: ^0(%arg0 : tensor<8x8xf64>, %arg1 : tensor<8x8xf64>):
// CHECK-NEXT: "func.return"(%arg0, %arg1) : (tensor<8x8xf64>, tensor<8x8xf64>) -> ()
// CHECK-NEXT: }) : () -> ()

func.func @output_attributes() -> (f32 {dialect.a = 0 : i32}, f32 {dialect.b = 0 : i32, dialect.c = 1 : i64}) {
%r1, %r2 = "test.op"() : () -> (f32, f32)
return %r1, %r2 : f32, f32
}

// CHECK: "func.func"() <{"sym_name" = "output_attributes", "function_type" = () -> (f32, f32), "res_attrs" = [{"dialect.a" = 0 : i32}, {"dialect.b" = 0 : i32, "dialect.c" = 1 : i64}]}> ({
// CHECK-NEXT: %r1, %r2 = "test.op"() : () -> (f32, f32)
// CHECK-NEXT: "func.return"(%r1, %r2) : (f32, f32) -> ()
// CHECK-NEXT: }) : () -> ()
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,14 @@ builtin.module {
// CHECK: func.func public @arg_attrs(%{{.*}}: tensor<8x8xf64> {"llvm.noalias"}, %{{.*}}: tensor<8x8xf64> {"llvm.noalias"}, %{{.*}}: tensor<8x8xf64> {"llvm.noalias"}) -> tensor<8x8xf64> {
// CHECK-NEXT: func.return %{{.*}} : tensor<8x8xf64>
// CHECK-NEXT: }

func.func @output_attributes() -> (f32 {dialect.a = 0 : i32}, f32 {dialect.b = 0 : i32, dialect.c = 1 : i64}) {
%r1, %r2 = "test.op"() : () -> (f32, f32)
return %r1, %r2 : f32, f32
}

// CHECK: func.func @output_attributes() -> (f32 {"dialect.a" = 0 : i32}, f32 {"dialect.b" = 0 : i32, "dialect.c" = 1 : i64}) {
// CHECK-NEXT: %0, %1 = "test.op"() : () -> (f32, f32)
// CHECK-NEXT: func.return %0, %1 : f32, f32
// CHECK-NEXT: }
}
18 changes: 8 additions & 10 deletions xdsl/dialects/arm_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,16 @@ def __init__(
@classmethod
def parse(cls, parser: Parser) -> FuncOp:
visibility = parser.parse_optional_visibility_keyword()
(
name,
input_types,
return_types,
region,
extra_attrs,
arg_attrs,
) = parse_func_op_like(
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
(name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = (
parse_func_op_like(
parser,
reserved_attr_names=("sym_name", "function_type", "sym_visibility"),
)
)
if arg_attrs:
raise NotImplementedError("arg_attrs not implemented in riscv_func")
raise NotImplementedError("arg_attrs not implemented in arm_func")
if res_attrs:
raise NotImplementedError("res_attrs not implemented in arm_func")
func = FuncOp(name, region, (input_types, return_types), visibility)
if extra_attrs is not None:
func.attributes |= extra_attrs.data
Expand Down
33 changes: 15 additions & 18 deletions xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,17 +779,16 @@ def verify_(self) -> None:

@classmethod
def parse(cls, parser: Parser) -> FuncOp:
(
name,
input_types,
return_types,
region,
extra_attrs,
arg_attrs,
) = parse_func_op_like(
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
(name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = (
parse_func_op_like(
parser,
reserved_attr_names=("sym_name", "function_type", "sym_visibility"),
)
)

if res_attrs:
raise NotImplementedError("res_attrs not implemented in csl FuncOp")

assert (
len(return_types) <= 1
), f"{cls.name} can't have more than one result type!"
Expand Down Expand Up @@ -889,16 +888,14 @@ def verify_(self) -> None:
@classmethod
def parse(cls, parser: Parser) -> TaskOp:
pos = parser.pos
(
name,
input_types,
return_types,
region,
extra_attrs,
arg_attrs,
) = parse_func_op_like(
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
(name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = (
parse_func_op_like(
parser,
reserved_attr_names=("sym_name", "function_type", "sym_visibility"),
)
)
if res_attrs:
raise NotImplementedError("res_attrs not implemented in csl TaskOp")
if (
extra_attrs is None
or "kind" not in extra_attrs.data
Expand Down
3 changes: 3 additions & 0 deletions xdsl/dialects/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def parse(cls, parser: Parser) -> FuncOp:
region,
extra_attrs,
arg_attrs,
res_attrs,
) = parse_func_op_like(
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
)
Expand All @@ -138,6 +139,7 @@ def parse(cls, parser: Parser) -> FuncOp:
region=region,
visibility=visibility,
arg_attrs=arg_attrs,
res_attrs=res_attrs,
)
if extra_attrs is not None:
func.attributes |= extra_attrs.data
Expand All @@ -155,6 +157,7 @@ def print(self, printer: Printer):
self.body,
self.attributes,
arg_attrs=self.arg_attrs,
res_attrs=self.res_attrs,
reserved_attr_names=(
"sym_name",
"function_type",
Expand Down
16 changes: 7 additions & 9 deletions xdsl/dialects/riscv_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,18 +174,16 @@ def __init__(
@classmethod
def parse(cls, parser: Parser) -> FuncOp:
visibility = parser.parse_optional_visibility_keyword()
(
name,
input_types,
return_types,
region,
extra_attrs,
arg_attrs,
) = parse_func_op_like(
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
(name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = (
parse_func_op_like(
parser,
reserved_attr_names=("sym_name", "function_type", "sym_visibility"),
)
)
if arg_attrs:
raise NotImplementedError("arg_attrs not implemented in riscv_func")
if res_attrs:
raise NotImplementedError("res_attrs not implemented in riscv_func")
func = FuncOp(name, region, (input_types, return_types), visibility)
if extra_attrs is not None:
func.attributes |= extra_attrs.data
Expand Down
60 changes: 52 additions & 8 deletions xdsl/dialects/utils/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def print_func_op_like(
attributes: dict[str, Attribute],
*,
arg_attrs: ArrayAttr[DictionaryAttr] | None = None,
res_attrs: ArrayAttr[DictionaryAttr] | None = None,
reserved_attr_names: Sequence[str],
):
printer.print(f" @{sym_name.data}")
Expand All @@ -62,7 +63,15 @@ def print_func_op_like(
printer.print("-> ")
if len(function_type.outputs) > 1:
printer.print("(")
printer.print_list(function_type.outputs, printer.print_attribute)
if res_attrs is not None:
printer.print_list(
zip(function_type.outputs, res_attrs),
lambda arg_with_attrs: print_func_output(
printer, arg_with_attrs[0], arg_with_attrs[1]
),
)
else:
printer.print_list(function_type.outputs, printer.print_attribute)
if len(function_type.outputs) > 1:
printer.print(")")
printer.print(" ")
Expand All @@ -85,9 +94,10 @@ def parse_func_op_like(
Region,
DictionaryAttr | None,
ArrayAttr[DictionaryAttr] | None,
ArrayAttr[DictionaryAttr] | None,
]:
"""
Returns the function name, argument types, return types, body, extra args, and arg_attrs.
Returns the function name, argument types, return types, body, extra args, arg_attrs and res_attrs.
"""
# Parse function name
name = parser.parse_symbol_name().data
Expand All @@ -103,6 +113,13 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute]
ret = (arg, arg_attr_dict)
return ret

def parse_fun_output() -> tuple[Attribute, dict[str, Attribute]]:
arg_type = parser.parse_optional_type()
if arg_type is None:
parser.raise_error("Return type should be specified")
arg_attr_dict = parser.parse_optional_dictionary_attr_dict()
return (arg_type, arg_attr_dict)

# Parse function arguments
args = parser.parse_comma_separated_list(
parser.Delimiter.PAREN,
Expand Down Expand Up @@ -135,14 +152,25 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute]
arg_attrs = None

# Parse return type
return_types: list[Attribute] = []
res_attrs_raw: list[dict[str, Attribute]] | None = []
if parser.parse_optional_punctuation("->"):
return_types = parser.parse_optional_comma_separated_list(
parser.Delimiter.PAREN, parser.parse_type
return_attributes = parser.parse_optional_comma_separated_list(
parser.Delimiter.PAREN, parse_fun_output
)
if return_types is None:
return_types = [parser.parse_type()]
if return_attributes is None:
# output attributes are supported only if return results are enclosed in brackets (...)
return_types, res_attrs_raw = [parser.parse_type()], None
else:
return_types, res_attrs_raw = (
[el[0] for el in return_attributes],
[el[1] for el in return_attributes],
)

if res_attrs_raw is not None and any(res_attrs_raw):
res_attrs = ArrayAttr(DictionaryAttr(attrs) for attrs in res_attrs_raw)
else:
return_types = []
res_attrs = None

extra_attributes = parser.parse_optional_attr_dict_with_keyword(reserved_attr_names)

Expand All @@ -151,7 +179,15 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute]
if region is None:
region = Region()

return name, input_types, return_types, region, extra_attributes, arg_attrs
return (
name,
input_types,
return_types,
region,
extra_attributes,
arg_attrs,
res_attrs,
)


def print_func_argument(
Expand All @@ -162,6 +198,14 @@ def print_func_argument(
printer.print_op_attributes(attrs.data)


def print_func_output(
printer: Printer, out_type: Attribute, attrs: DictionaryAttr | None
):
printer.print_attribute(out_type)
if attrs is not None and attrs.data:
printer.print_op_attributes(attrs.data)


def print_assignment(printer: Printer, arg: BlockArgument, val: SSAValue):
printer.print_block_argument(arg, print_type=False)
printer.print_string(" = ")
Expand Down
Loading