Skip to content

Commit

Permalink
pyright: resolve memref_descriptor and test_snax (#335)
Browse files Browse the repository at this point in the history
* pyright: resolve memref_descriptor and test_snax

* update pixi lockfile
  • Loading branch information
jorendumoulin authored Jan 17, 2025
1 parent 3ce8fdf commit 546946e
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 25 deletions.
36 changes: 17 additions & 19 deletions compiler/util/memref_descriptor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from xdsl.dialects.builtin import ArrayAttr, IntegerType
from xdsl.dialects.builtin import IntegerType, MemRefType
from xdsl.dialects.llvm import LLVMArrayType, LLVMPointerType, LLVMStructType
from xdsl.dialects.memref import MemRefType
from xdsl.ir import Attribute
from xdsl.utils.exceptions import VerifyException

# this file contains useful helper functions to work with
Expand Down Expand Up @@ -58,7 +58,7 @@ def from_rank_and_integer_type(

@classmethod
def from_memref_type(
cls, memref_type: MemRefType, integer_type: IntegerType
cls, memref_type: MemRefType[Attribute], integer_type: IntegerType
) -> "LLVMMemrefDescriptor":
"""
Create an LLVMMemrefDescriptor from a MemRefType.
Expand All @@ -71,9 +71,10 @@ def from_memref_type(
LLVMMemrefDescriptor: The created descriptor.
"""

return cls.from_rank_and_integer_type(
memref_type.get_num_dims(), memref_type.get_element_type()
)
el_type = memref_type.get_element_type()
assert isinstance(el_type, IntegerType)

return cls.from_rank_and_integer_type(memref_type.get_num_dims(), el_type)

def verify(self) -> None:
"""
Expand All @@ -83,38 +84,35 @@ def verify(self) -> None:
VerifyException: If the memref descriptor is invalid.
"""

def raise_exception(message: str) -> None:
raise VerifyException("Invalid Memref Descriptor: " + message)

if not isinstance(self.descriptor.types, ArrayAttr):
raise VerifyException("Expected result type to have an ArrayAttr")
def exception(message: str) -> VerifyException:
return VerifyException("Invalid Memref Descriptor: " + message)

type_iter = iter(self.descriptor.types.data)

if not isinstance(next(type_iter), LLVMPointerType):
raise_exception("Expected first element to be LLVMPointerType")
raise exception("Expected first element to be LLVMPointerType")

if not isinstance(next(type_iter), LLVMPointerType):
raise_exception("Expected second element to be LLVMPointerType")
raise exception("Expected second element to be LLVMPointerType")

if not isinstance(next(type_iter), IntegerType):
raise_exception("Expected third element to be IntegerType")
raise exception("Expected third element to be IntegerType")

shape = next(type_iter)
if not isinstance(shape, LLVMArrayType):
raise_exception("Expected fourth element to be LLVMArrayType")
raise exception("Expected fourth element to be LLVMArrayType")

if not isinstance(shape.type, IntegerType):
raise_exception(
raise exception(
"Expected fourth element to be LLVMArrayType of IntegerType"
)

strides = next(type_iter)
if not isinstance(strides, LLVMArrayType):
raise_exception("Expected fifth element to be LLVMArrayType")
raise exception("Expected fifth element to be LLVMArrayType")

if not isinstance(strides.type, IntegerType):
raise_exception("Expected fifth element to be LLVMArrayType of IntegerType")
raise exception("Expected fifth element to be LLVMArrayType of IntegerType")

if not strides.size.data == shape.size.data:
raise_exception("Expected shape and strides to have the same dimension")
raise exception("Expected shape and strides to have the same dimension")
2 changes: 1 addition & 1 deletion pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ typeCheckingMode = "strict"
"compiler/transforms/snax_lower_mcycle.py",
"compiler/transforms/snax_to_func.py",
"compiler/transforms/test_remove_memref_copy.py",
"compiler/util/memref_descriptor.py",
"tests/benchmark/test_snax_benchmark.py",
"tests/dialects/test_snax.py",
"tests/inference/test_accfg_state_tracing.py",
"tests/util/",
]
Expand Down
9 changes: 6 additions & 3 deletions tests/dialects/test_snax.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import pytest
from xdsl.dialects import builtin
from xdsl.dialects.builtin import ArrayAttr, StridedLayoutAttr, i32, i64
from xdsl.dialects.builtin import ArrayAttr, MemRefType, StridedLayoutAttr, i32, i64
from xdsl.dialects.llvm import LLVMStructType
from xdsl.dialects.memref import MemRefType
from xdsl.ir import Attribute
from xdsl.parser import StringAttr
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.hints import isa
from xdsl.utils.test_value import TestSSAValue

from compiler.dialects.snax import Alloc, LayoutCast
Expand Down Expand Up @@ -62,7 +64,7 @@ def test_memref_memory_space_cast():
memory_layout_cast = LayoutCast.from_type_and_target_layout(source_ssa, layout_2)

assert memory_layout_cast.source is source_ssa
assert isinstance(memory_layout_cast.dest.type, MemRefType)
assert isa(memory_layout_cast.dest.type, MemRefType[Attribute])
assert memory_layout_cast.dest.type.layout is layout_2


Expand All @@ -73,6 +75,7 @@ def test_snax_alloc():
alloc_a = Alloc(dim, size, shape, memory_space=builtin.StringAttr("L1"))

assert alloc_a.size is size
assert isinstance(alloc_a.memory_space, StringAttr)
assert alloc_a.memory_space.data == "L1"

assert isinstance(alloc_a.result.type, LLVMStructType)
Expand Down

0 comments on commit 546946e

Please sign in to comment.