Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: Add attributes to func outputs #3661

Merged
merged 9 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions tests/filecheck/dialects/func/func_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,14 @@ builtin.module {
// CHECK: func.func public @arg_attrs(%{{.*}} : tensor<8x8xf64> {"llvm.noalias"}, %{{.*}} : tensor<8x8xf64> {"llvm.noalias"}, %{{.*}} : tensor<8x8xf64> {"llvm.noalias"}) -> tensor<8x8xf64> {
// CHECK-NEXT: return %{{.*}} : tensor<8x8xf64>
// CHECK-NEXT: }

func.func @output_attributes() -> (f32 {a = 0 : i32}, f32 {b = 0 : i32, c = 1 : i64}) {
%r1, %r2 = "test.op"() : () -> (f32, f32)
return %r1, %r2 : f32, f32
}

// CHECK: func.func @output_attributes() -> (f32 {"a" = 0 : i32}, f32 {"b" = 0 : i32, "c" = 1 : i64}) {
// CHECK-NEXT: %r1, %r2 = "test.op"() : () -> (f32, f32)
// CHECK-NEXT: func.return %r1, %r2 : f32, f32
// CHECK-NEXT: }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have a generic format filecheck and an mlir integration filecheck for this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. Done.

}
14 changes: 5 additions & 9 deletions xdsl/dialects/arm_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,11 @@ def __init__(
@classmethod
def parse(cls, parser: Parser) -> FuncOp:
visibility = parser.parse_optional_visibility_keyword()
(
name,
input_types,
return_types,
region,
extra_attrs,
arg_attrs,
) = parse_func_op_like(
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
(name, input_types, return_types, region, extra_attrs, arg_attrs, _) = (
parse_func_op_like(
parser,
reserved_attr_names=("sym_name", "function_type", "sym_visibility"),
)
)
if arg_attrs:
raise NotImplementedError("arg_attrs not implemented in riscv_func")
Expand Down
28 changes: 10 additions & 18 deletions xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,15 +779,11 @@ def verify_(self) -> None:

@classmethod
def parse(cls, parser: Parser) -> FuncOp:
(
name,
input_types,
return_types,
region,
extra_attrs,
arg_attrs,
) = parse_func_op_like(
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
(name, input_types, return_types, region, extra_attrs, arg_attrs, _) = (
parse_func_op_like(
parser,
reserved_attr_names=("sym_name", "function_type", "sym_visibility"),
)
)

assert (
Expand Down Expand Up @@ -889,15 +885,11 @@ def verify_(self) -> None:
@classmethod
def parse(cls, parser: Parser) -> TaskOp:
pos = parser.pos
(
name,
input_types,
return_types,
region,
extra_attrs,
arg_attrs,
) = parse_func_op_like(
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
(name, input_types, return_types, region, extra_attrs, arg_attrs, _) = (
parse_func_op_like(
parser,
reserved_attr_names=("sym_name", "function_type", "sym_visibility"),
)
)
if (
extra_attrs is None
Expand Down
3 changes: 3 additions & 0 deletions xdsl/dialects/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def parse(cls, parser: Parser) -> FuncOp:
region,
extra_attrs,
arg_attrs,
res_attrs,
) = parse_func_op_like(
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
)
Expand All @@ -138,6 +139,7 @@ def parse(cls, parser: Parser) -> FuncOp:
region=region,
visibility=visibility,
arg_attrs=arg_attrs,
res_attrs=res_attrs,
)
if extra_attrs is not None:
func.attributes |= extra_attrs.data
Expand All @@ -155,6 +157,7 @@ def print(self, printer: Printer):
self.body,
self.attributes,
arg_attrs=self.arg_attrs,
res_attrs=self.res_attrs,
reserved_attr_names=(
"sym_name",
"function_type",
Expand Down
14 changes: 5 additions & 9 deletions xdsl/dialects/riscv_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,11 @@ def __init__(
@classmethod
def parse(cls, parser: Parser) -> FuncOp:
visibility = parser.parse_optional_visibility_keyword()
(
name,
input_types,
return_types,
region,
extra_attrs,
arg_attrs,
) = parse_func_op_like(
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
(name, input_types, return_types, region, extra_attrs, arg_attrs, _) = (
parse_func_op_like(
parser,
reserved_attr_names=("sym_name", "function_type", "sym_visibility"),
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you please raise NotImplementedError if there are any res_attrs here and in other places?

)
if arg_attrs:
raise NotImplementedError("arg_attrs not implemented in riscv_func")
Expand Down
60 changes: 52 additions & 8 deletions xdsl/dialects/utils/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def print_func_op_like(
attributes: dict[str, Attribute],
*,
arg_attrs: ArrayAttr[DictionaryAttr] | None = None,
res_attrs: ArrayAttr[DictionaryAttr] | None = None,
reserved_attr_names: Sequence[str],
):
printer.print(f" @{sym_name.data}")
Expand All @@ -62,7 +63,15 @@ def print_func_op_like(
printer.print("-> ")
if len(function_type.outputs) > 1:
printer.print("(")
printer.print_list(function_type.outputs, printer.print_attribute)
if res_attrs is not None:
printer.print_list(
zip(function_type.outputs, res_attrs),
lambda arg_with_attrs: print_func_output(
printer, arg_with_attrs[0], arg_with_attrs[1]
),
)
else:
printer.print_list(function_type.outputs, printer.print_attribute)
if len(function_type.outputs) > 1:
printer.print(")")
printer.print(" ")
Expand All @@ -85,9 +94,10 @@ def parse_func_op_like(
Region,
DictionaryAttr | None,
ArrayAttr[DictionaryAttr] | None,
ArrayAttr[DictionaryAttr] | None,
]:
"""
Returns the function name, argument types, return types, body, extra args, and arg_attrs.
Returns the function name, argument types, return types, body, extra args, arg_attrs and res_attrs.
"""
# Parse function name
name = parser.parse_symbol_name().data
Expand All @@ -103,6 +113,13 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute]
ret = (arg, arg_attr_dict)
return ret

def parse_fun_output() -> tuple[Attribute, dict[str, Attribute]]:
arg_type = parser.parse_optional_type()
if arg_type is None:
parser.raise_error("Return type should be specified")
arg_attr_dict = parser.parse_optional_dictionary_attr_dict()
return (arg_type, arg_attr_dict)

# Parse function arguments
args = parser.parse_comma_separated_list(
parser.Delimiter.PAREN,
Expand Down Expand Up @@ -135,14 +152,25 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute]
arg_attrs = None

# Parse return type
return_types: list[Attribute] = []
res_attrs_raw: list[dict[str, Attribute]] | None = []
if parser.parse_optional_punctuation("->"):
return_types = parser.parse_optional_comma_separated_list(
parser.Delimiter.PAREN, parser.parse_type
return_attributes = parser.parse_optional_comma_separated_list(
parser.Delimiter.PAREN, parse_fun_output
)
if return_types is None:
return_types = [parser.parse_type()]
if return_attributes is None:
# output attributes are supported only if return results are enclosed in brackets (...)
return_types, res_attrs_raw = [parser.parse_type()], None
else:
return_types, res_attrs_raw = (
[el[0] for el in return_attributes],
[el[1] for el in return_attributes],
)

if res_attrs_raw is not None and any(res_attrs_raw):
res_attrs = ArrayAttr(DictionaryAttr(attrs) for attrs in res_attrs_raw)
else:
return_types = []
res_attrs = None

extra_attributes = parser.parse_optional_attr_dict_with_keyword(reserved_attr_names)

Expand All @@ -151,7 +179,15 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute]
if region is None:
region = Region()

return name, input_types, return_types, region, extra_attributes, arg_attrs
return (
name,
input_types,
return_types,
region,
extra_attributes,
arg_attrs,
res_attrs,
)


def print_func_argument(
Expand All @@ -162,6 +198,14 @@ def print_func_argument(
printer.print_op_attributes(attrs.data)


def print_func_output(
printer: Printer, out_type: Attribute, attrs: DictionaryAttr | None
):
printer.print_attribute(out_type)
if attrs is not None and attrs.data:
printer.print_op_attributes(attrs.data)


def print_assignment(printer: Printer, arg: BlockArgument, val: SSAValue):
printer.print_block_argument(arg, print_type=False)
printer.print_string(" = ")
Expand Down
Loading