Skip to content

Commit

Permalink
core: Add attributes to func outputs (#3661)
Browse files Browse the repository at this point in the history
Mlir
[supports](https://mlir.llvm.org/docs/Dialects/Func/#funcfunc-funcfuncop)
attributes for output types. We need for cross compatibility with other
mlir projects.
  • Loading branch information
mamanain authored Dec 21, 2024
1 parent e5e9069 commit 955600d
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 45 deletions.
10 changes: 10 additions & 0 deletions tests/filecheck/dialects/func/func_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,14 @@ builtin.module {
// CHECK: func.func public @arg_attrs(%{{.*}} : tensor<8x8xf64> {"llvm.noalias"}, %{{.*}} : tensor<8x8xf64> {"llvm.noalias"}, %{{.*}} : tensor<8x8xf64> {"llvm.noalias"}) -> tensor<8x8xf64> {
// CHECK-NEXT: return %{{.*}} : tensor<8x8xf64>
// CHECK-NEXT: }

func.func @output_attributes() -> (f32 {dialect.a = 0 : i32}, f32 {dialect.b = 0 : i32, dialect.c = 1 : i64}) {
%r1, %r2 = "test.op"() : () -> (f32, f32)
return %r1, %r2 : f32, f32
}

// CHECK: func.func @output_attributes() -> (f32 {"dialect.a" = 0 : i32}, f32 {"dialect.b" = 0 : i32, "dialect.c" = 1 : i64}) {
// CHECK-NEXT: %r1, %r2 = "test.op"() : () -> (f32, f32)
// CHECK-NEXT: func.return %r1, %r2 : f32, f32
// CHECK-NEXT: }
}
10 changes: 10 additions & 0 deletions tests/filecheck/dialects/func/func_ops_generic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,13 @@
// CHECK-NEXT: ^0(%arg0 : tensor<8x8xf64>, %arg1 : tensor<8x8xf64>):
// CHECK-NEXT: "func.return"(%arg0, %arg1) : (tensor<8x8xf64>, tensor<8x8xf64>) -> ()
// CHECK-NEXT: }) : () -> ()

func.func @output_attributes() -> (f32 {dialect.a = 0 : i32}, f32 {dialect.b = 0 : i32, dialect.c = 1 : i64}) {
%r1, %r2 = "test.op"() : () -> (f32, f32)
return %r1, %r2 : f32, f32
}

// CHECK: "func.func"() <{"sym_name" = "output_attributes", "function_type" = () -> (f32, f32), "res_attrs" = [{"dialect.a" = 0 : i32}, {"dialect.b" = 0 : i32, "dialect.c" = 1 : i64}]}> ({
// CHECK-NEXT: %r1, %r2 = "test.op"() : () -> (f32, f32)
// CHECK-NEXT: "func.return"(%r1, %r2) : (f32, f32) -> ()
// CHECK-NEXT: }) : () -> ()
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,14 @@ builtin.module {
// CHECK: func.func public @arg_attrs(%{{.*}}: tensor<8x8xf64> {"llvm.noalias"}, %{{.*}}: tensor<8x8xf64> {"llvm.noalias"}, %{{.*}}: tensor<8x8xf64> {"llvm.noalias"}) -> tensor<8x8xf64> {
// CHECK-NEXT: func.return %{{.*}} : tensor<8x8xf64>
// CHECK-NEXT: }

func.func @output_attributes() -> (f32 {dialect.a = 0 : i32}, f32 {dialect.b = 0 : i32, dialect.c = 1 : i64}) {
%r1, %r2 = "test.op"() : () -> (f32, f32)
return %r1, %r2 : f32, f32
}

// CHECK: func.func @output_attributes() -> (f32 {"dialect.a" = 0 : i32}, f32 {"dialect.b" = 0 : i32, "dialect.c" = 1 : i64}) {
// CHECK-NEXT: %0, %1 = "test.op"() : () -> (f32, f32)
// CHECK-NEXT: func.return %0, %1 : f32, f32
// CHECK-NEXT: }
}
18 changes: 8 additions & 10 deletions xdsl/dialects/arm_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,16 @@ def __init__(
@classmethod
def parse(cls, parser: Parser) -> FuncOp:
visibility = parser.parse_optional_visibility_keyword()
(
name,
input_types,
return_types,
region,
extra_attrs,
arg_attrs,
) = parse_func_op_like(
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
(name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = (
parse_func_op_like(
parser,
reserved_attr_names=("sym_name", "function_type", "sym_visibility"),
)
)
if arg_attrs:
raise NotImplementedError("arg_attrs not implemented in riscv_func")
raise NotImplementedError("arg_attrs not implemented in arm_func")
if res_attrs:
raise NotImplementedError("res_attrs not implemented in arm_func")
func = FuncOp(name, region, (input_types, return_types), visibility)
if extra_attrs is not None:
func.attributes |= extra_attrs.data
Expand Down
33 changes: 15 additions & 18 deletions xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,17 +780,16 @@ def verify_(self) -> None:

@classmethod
def parse(cls, parser: Parser) -> FuncOp:
(
name,
input_types,
return_types,
region,
extra_attrs,
arg_attrs,
) = parse_func_op_like(
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
(name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = (
parse_func_op_like(
parser,
reserved_attr_names=("sym_name", "function_type", "sym_visibility"),
)
)

if res_attrs:
raise NotImplementedError("res_attrs not implemented in csl FuncOp")

assert (
len(return_types) <= 1
), f"{cls.name} can't have more than one result type!"
Expand Down Expand Up @@ -890,16 +889,14 @@ def verify_(self) -> None:
@classmethod
def parse(cls, parser: Parser) -> TaskOp:
pos = parser.pos
(
name,
input_types,
return_types,
region,
extra_attrs,
arg_attrs,
) = parse_func_op_like(
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
(name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = (
parse_func_op_like(
parser,
reserved_attr_names=("sym_name", "function_type", "sym_visibility"),
)
)
if res_attrs:
raise NotImplementedError("res_attrs not implemented in csl TaskOp")
if (
extra_attrs is None
or "kind" not in extra_attrs.data
Expand Down
3 changes: 3 additions & 0 deletions xdsl/dialects/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def parse(cls, parser: Parser) -> FuncOp:
region,
extra_attrs,
arg_attrs,
res_attrs,
) = parse_func_op_like(
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
)
Expand All @@ -175,6 +176,7 @@ def parse(cls, parser: Parser) -> FuncOp:
region=region,
visibility=visibility,
arg_attrs=arg_attrs,
res_attrs=res_attrs,
)
if extra_attrs is not None:
func.attributes |= extra_attrs.data
Expand All @@ -192,6 +194,7 @@ def print(self, printer: Printer):
self.body,
self.attributes,
arg_attrs=self.arg_attrs,
res_attrs=self.res_attrs,
reserved_attr_names=(
"sym_name",
"function_type",
Expand Down
16 changes: 7 additions & 9 deletions xdsl/dialects/riscv_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,18 +174,16 @@ def __init__(
@classmethod
def parse(cls, parser: Parser) -> FuncOp:
visibility = parser.parse_optional_visibility_keyword()
(
name,
input_types,
return_types,
region,
extra_attrs,
arg_attrs,
) = parse_func_op_like(
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
(name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = (
parse_func_op_like(
parser,
reserved_attr_names=("sym_name", "function_type", "sym_visibility"),
)
)
if arg_attrs:
raise NotImplementedError("arg_attrs not implemented in riscv_func")
if res_attrs:
raise NotImplementedError("res_attrs not implemented in riscv_func")
func = FuncOp(name, region, (input_types, return_types), visibility)
if extra_attrs is not None:
func.attributes |= extra_attrs.data
Expand Down
60 changes: 52 additions & 8 deletions xdsl/dialects/utils/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def print_func_op_like(
attributes: dict[str, Attribute],
*,
arg_attrs: ArrayAttr[DictionaryAttr] | None = None,
res_attrs: ArrayAttr[DictionaryAttr] | None = None,
reserved_attr_names: Sequence[str],
):
printer.print(f" @{sym_name.data}")
Expand All @@ -62,7 +63,15 @@ def print_func_op_like(
printer.print("-> ")
if len(function_type.outputs) > 1:
printer.print("(")
printer.print_list(function_type.outputs, printer.print_attribute)
if res_attrs is not None:
printer.print_list(
zip(function_type.outputs, res_attrs),
lambda arg_with_attrs: print_func_output(
printer, arg_with_attrs[0], arg_with_attrs[1]
),
)
else:
printer.print_list(function_type.outputs, printer.print_attribute)
if len(function_type.outputs) > 1:
printer.print(")")
printer.print(" ")
Expand All @@ -85,9 +94,10 @@ def parse_func_op_like(
Region,
DictionaryAttr | None,
ArrayAttr[DictionaryAttr] | None,
ArrayAttr[DictionaryAttr] | None,
]:
"""
Returns the function name, argument types, return types, body, extra args, and arg_attrs.
Returns the function name, argument types, return types, body, extra args, arg_attrs and res_attrs.
"""
# Parse function name
name = parser.parse_symbol_name().data
Expand All @@ -103,6 +113,13 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute]
ret = (arg, arg_attr_dict)
return ret

def parse_fun_output() -> tuple[Attribute, dict[str, Attribute]]:
arg_type = parser.parse_optional_type()
if arg_type is None:
parser.raise_error("Return type should be specified")
arg_attr_dict = parser.parse_optional_dictionary_attr_dict()
return (arg_type, arg_attr_dict)

# Parse function arguments
args = parser.parse_comma_separated_list(
parser.Delimiter.PAREN,
Expand Down Expand Up @@ -135,14 +152,25 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute]
arg_attrs = None

# Parse return type
return_types: list[Attribute] = []
res_attrs_raw: list[dict[str, Attribute]] | None = []
if parser.parse_optional_punctuation("->"):
return_types = parser.parse_optional_comma_separated_list(
parser.Delimiter.PAREN, parser.parse_type
return_attributes = parser.parse_optional_comma_separated_list(
parser.Delimiter.PAREN, parse_fun_output
)
if return_types is None:
return_types = [parser.parse_type()]
if return_attributes is None:
# output attributes are supported only if return results are enclosed in brackets (...)
return_types, res_attrs_raw = [parser.parse_type()], None
else:
return_types, res_attrs_raw = (
[el[0] for el in return_attributes],
[el[1] for el in return_attributes],
)

if res_attrs_raw is not None and any(res_attrs_raw):
res_attrs = ArrayAttr(DictionaryAttr(attrs) for attrs in res_attrs_raw)
else:
return_types = []
res_attrs = None

extra_attributes = parser.parse_optional_attr_dict_with_keyword(reserved_attr_names)

Expand All @@ -151,7 +179,15 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute]
if region is None:
region = Region()

return name, input_types, return_types, region, extra_attributes, arg_attrs
return (
name,
input_types,
return_types,
region,
extra_attributes,
arg_attrs,
res_attrs,
)


def print_func_argument(
Expand All @@ -162,6 +198,14 @@ def print_func_argument(
printer.print_op_attributes(attrs.data)


def print_func_output(
printer: Printer, out_type: Attribute, attrs: DictionaryAttr | None
):
printer.print_attribute(out_type)
if attrs is not None and attrs.data:
printer.print_op_attributes(attrs.data)


def print_assignment(printer: Printer, arg: BlockArgument, val: SSAValue):
printer.print_block_argument(arg, print_type=False)
printer.print_string(" = ")
Expand Down

0 comments on commit 955600d

Please sign in to comment.