Skip to content

Commit

Permalink
pyright: fix custom dialects snax, snax_stream, tsl (#326)
Browse files Browse the repository at this point in the history
* pyright: fix custom dialects snax, snax_stream, tsl

* undo parser change

* undo breaking change
  • Loading branch information
jorendumoulin authored Jan 3, 2025
1 parent 43059c9 commit dd5510b
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 40 deletions.
25 changes: 14 additions & 11 deletions compiler/dialects/snax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
IndexType,
IntegerAttr,
IntegerType,
MemrefLayoutAttr,
MemRefType,
NoneAttr,
UnrankedMemrefType,
Expand Down Expand Up @@ -64,8 +65,8 @@ class LayoutCast(IRDLOperation):

name = "snax.layout_cast"

source = operand_def(MemRefType[Attribute] | UnrankedMemrefType[Attribute])
dest = result_def(MemRefType[Attribute] | UnrankedMemrefType[Attribute])
source = operand_def(MemRefType)
dest = result_def(MemRefType)

def __init__(
self,
Expand All @@ -77,14 +78,16 @@ def __init__(
@staticmethod
def from_type_and_target_layout(
source: SSAValue | Operation,
layout: Attribute,
layout: MemrefLayoutAttr,
) -> LayoutCast:
assert isinstance(source.type, MemRefType)
source = SSAValue.get(source)
assert isinstance(source.type, Attribute)
source_type = cast(MemRefType[Attribute], source.type)
dest = MemRefType(
source.type.get_element_type(),
shape=source.type.get_shape(),
source_type.get_element_type(),
source_type.get_shape(),
layout=layout,
memory_space=source.type.memory_space,
memory_space=source_type.memory_space,
)
return LayoutCast(source, dest)

Expand Down Expand Up @@ -117,8 +120,8 @@ class Alloc(IRDLOperation):

name = "snax.alloc"

size: Operand = operand_def(IntegerType | IndexType)
shapes: VarOperand = var_operand_def(IntegerType | IndexType)
size: Operand = operand_def(IndexType)
shapes: VarOperand = var_operand_def(IndexType)
result: OpResult = result_def(LLVMStructType)
memory_space: Attribute | None = opt_prop_def(Attribute)
alignment: AnyIntegerAttr | None = opt_prop_def(AnyIntegerAttr)
Expand All @@ -129,7 +132,7 @@ def __init__(
size: SSAValue | Operation,
shapes: list[SSAValue | Operation],
memory_space: Attribute = NoneAttr(),
alignment: AnyIntegerAttr = None,
alignment: AnyIntegerAttr | None = None,
integer_type: IntegerType = i32,
):
# output type is llvm struct memref descriptor
Expand Down Expand Up @@ -184,7 +187,7 @@ def parse_parameter(cls, parser: AttrParser) -> StreamerConfiguration:
parser.parse_punctuation("[")

# Determine streamer options
opts = []
opts: Sequence[StreamerOpts] = []
if parser.parse_optional_keyword("opts"):
parser.parse_punctuation("=")
while not parser.parse_optional_punctuation(","):
Expand Down
2 changes: 1 addition & 1 deletion compiler/dialects/snax_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
parameters: Sequence[Attribute] = []
for arg in (upper_bounds, temporal_strides, spatial_strides):
if not isinstance(arg, ArrayAttr):
arg = ArrayAttr([IntAttr(x) if isinstance(x, int) else x for x in arg])
arg = ArrayAttr([IntAttr(x) for x in arg])
parameters.append(arg)
super().__init__(parameters)

Expand Down
33 changes: 15 additions & 18 deletions compiler/dialects/tsl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from math import prod
from typing import cast

from xdsl.dialects.arith import ConstantOp, DivUIOp, MuliOp
from xdsl.dialects.builtin import (
Expand All @@ -11,7 +12,7 @@
StridedLayoutAttr,
)
from xdsl.dialects.memref import DimOp, ExtractStridedMetaDataOp
from xdsl.ir import Data, Dialect, Operation, SSAValue
from xdsl.ir import Attribute, Data, Dialect, Operation, SSAValue
from xdsl.ir.affine import AffineConstantExpr, AffineDimExpr, AffineMap
from xdsl.irdl import (
irdl_attr_definition,
Expand All @@ -32,7 +33,7 @@ class TiledStridedLayoutAttr(MemrefLayoutAttr, Data[TiledStridedLayout]):
@classmethod
def parse_parameter(cls, parser: AttrParser) -> TiledStridedLayout:
with parser.in_angle_brackets():
tslparser = TSLParser(parser._parser_state)
tslparser = TSLParser(parser._parser_state) # pyright: ignore
return tslparser.parse()

def print_parameter(self, printer: Printer) -> None:
Expand All @@ -42,10 +43,6 @@ def get_affine_map(self) -> AffineMap:
if self.data.is_dynamic():
raise NotImplementedError("Dynamic case is not implemented yet!")

# TODO: the affine map should result in element offset, not byte offset
# i will probably transition the tsl definition to element offset
# as well, to make everything more convenient

result = AffineConstantExpr(0)
for dim in range(self.data.dimension()):
max_depth = self.data.tstrides[dim].depth()
Expand Down Expand Up @@ -86,15 +83,15 @@ def get_bound_ops(
"""

result: list[Operation] = []
result_mapping: dict[(int, int), Operation] = {}
result_mapping: dict[tuple[int, int], Operation] = {}

tsl = self.data

if isinstance(memref_op_or_shapes, SSAValue | Operation):
# if the argument passed is a memref, generate shape operation
# list by using the dim operation
memref = memref_op_or_shapes
shapes = []
shapes: list[Operation] = []
for dim in range(tsl.dimension()):
dim_index_op = ConstantOp.from_int_and_width(dim, IndexType())
dim_op = DimOp.from_source_and_index(memref, dim_index_op)
Expand Down Expand Up @@ -134,6 +131,7 @@ def get_bound_ops(
# inner tile depths are all static by definition of TSL
for depth in range(1, tsl.tstrides[dim].depth()):
stride = tsl.get_stride(dim, depth)
assert stride.bound is not None
bound_op = ConstantOp.from_int_and_width(stride.bound, IndexType())
result.append(bound_op)
result_mapping[(dim, depth)] = bound_op
Expand Down Expand Up @@ -174,14 +172,14 @@ def get_step_ops(
# In this case, if there are dynamic strides, we cannot perform
# the TSL contiguity assumptions. Instead, dynamic strides are
# fetched from the extract strided metadata operation.
if (
memref_op
and isinstance(memref_op.type, MemRefType)
and isinstance(memref_op.type.layout, StridedLayoutAttr)
if memref_op and isinstance(
(memref_type := cast(MemRefType[Attribute], memref_op.type)).layout,
StridedLayoutAttr,
):
metadata_op = ExtractStridedMetaDataOp(memref_op)
assert isinstance(memref_type.element_type, FixedBitwidthType)
element_size_op = ConstantOp.from_int_and_width(
memref_op.type.element_type.width.data // 8, IndexType()
memref_type.element_type.size, IndexType()
)
result.extend([metadata_op, element_size_op])
for dim in range(tsl.dimension()):
Expand All @@ -193,12 +191,11 @@ def get_step_ops(

# optional bytes correction
if in_bytes:
assert memref_op
assert isinstance(memref_op.type, MemRefType)
assert isinstance(memref_op.type.element_type, FixedBitwidthType)
el_bytes = memref_op.type.element_type.size
assert memref_op is not None
memref_type = cast(MemRefType[Attribute], memref_op.type)
assert isinstance(memref_type.element_type, FixedBitwidthType)
el_bytes = memref_type.element_type.size
else:
# else use 1 such that 1 element = 1 byte
el_bytes = 1

# to handle the dynamic case, we must first find the largest
Expand Down
2 changes: 1 addition & 1 deletion pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ typeCheckingMode = "strict"
"compiler/accelerators/snax_gemmx.py",
"compiler/accelerators/snax_hwpe_mult.py",
"compiler/dialects/accfg.py",
"compiler/dialects/snax.py",
"compiler/dialects/snax_stream.py",
"compiler/dialects/tsl.py",
"compiler/inference/dataflow.py",
"compiler/inference/helpers.py",
"compiler/inference/scoped_setups.py",
Expand Down
12 changes: 6 additions & 6 deletions tests/filecheck/dialects/snax/snax_invalid.mlir
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
// RUN: XDSL_VERIFY_DIAG

"builtin.module"() ({
%0 = arith.constant 45 : i32
%1 = "snax.alloc"(%0) <{"memory_space" = 1 : i32, "alignment" = 64 : i32}> : (i32) -> !llvm.struct<(i32)>
%0 = arith.constant 45 : index
%1 = "snax.alloc"(%0) <{"memory_space" = 1 : i32, "alignment" = 64 : i32}> : (index) -> !llvm.struct<(i32)>
}) : () -> ()

// CHECK: Operation does not verify: Invalid Memref Descriptor: Expected first element to be LLVMPointerType

// -----

"builtin.module"() ({
%0 = arith.constant 45 : i32
%1 = "snax.alloc"(%0) <{"memory_space" = 2 : i32, "alignment" = 64 : i32}> : (i32) -> !llvm.struct<(!llvm.ptr, !llvm.ptr, !llvm.array<2 x i32>, !llvm.array<2 x i32>)>
%0 = arith.constant 45 : index
%1 = "snax.alloc"(%0) <{"memory_space" = 2 : i32, "alignment" = 64 : i32}> : (index) -> !llvm.struct<(!llvm.ptr, !llvm.ptr, !llvm.array<2 x i32>, !llvm.array<2 x i32>)>
}) : () -> ()

// CHECK: Operation does not verify: Invalid Memref Descriptor: Expected third element to be IntegerType

// -----

"builtin.module"() ({
%0 = arith.constant 45 : i32
%1 = "snax.alloc"(%0) <{"memory_space" = 2 : i32, "alignment" = 64 : i32}> : (i32) -> !llvm.struct<(!llvm.ptr, !llvm.ptr, i32, !llvm.array<1 x i32>, !llvm.array<2 x i32>)>
%0 = arith.constant 45 : index
%1 = "snax.alloc"(%0) <{"memory_space" = 2 : i32, "alignment" = 64 : i32}> : (index) -> !llvm.struct<(!llvm.ptr, !llvm.ptr, i32, !llvm.array<1 x i32>, !llvm.array<2 x i32>)>
}) : () -> ()

// CHECK: Operation does not verify: Invalid Memref Descriptor: Expected shape and strides to have the same dimension

0 comments on commit dd5510b

Please sign in to comment.