From 546946edf695ea76a43c46ba6a3b0da70537ec96 Mon Sep 17 00:00:00 2001 From: Joren Dumoulin Date: Fri, 17 Jan 2025 10:28:43 +0100 Subject: [PATCH] pyright: resolve memref_descriptor and test_snax (#335) * pyright: resolve memref_descriptor and test_snax * update pixi lockfile --- compiler/util/memref_descriptor.py | 36 ++++++++++++++---------------- pixi.lock | 2 +- pyproject.toml | 2 -- tests/dialects/test_snax.py | 9 +++++--- 4 files changed, 24 insertions(+), 25 deletions(-) diff --git a/compiler/util/memref_descriptor.py b/compiler/util/memref_descriptor.py index 7cd28f4a..f6304652 100644 --- a/compiler/util/memref_descriptor.py +++ b/compiler/util/memref_descriptor.py @@ -1,6 +1,6 @@ -from xdsl.dialects.builtin import ArrayAttr, IntegerType +from xdsl.dialects.builtin import IntegerType, MemRefType from xdsl.dialects.llvm import LLVMArrayType, LLVMPointerType, LLVMStructType -from xdsl.dialects.memref import MemRefType +from xdsl.ir import Attribute from xdsl.utils.exceptions import VerifyException # this file contains useful helper functions to work with @@ -58,7 +58,7 @@ def from_rank_and_integer_type( @classmethod def from_memref_type( - cls, memref_type: MemRefType, integer_type: IntegerType + cls, memref_type: MemRefType[Attribute], integer_type: IntegerType ) -> "LLVMMemrefDescriptor": """ Create an LLVMMemrefDescriptor from a MemRefType. @@ -71,9 +71,10 @@ def from_memref_type( LLVMMemrefDescriptor: The created descriptor. """ - return cls.from_rank_and_integer_type( - memref_type.get_num_dims(), memref_type.get_element_type() - ) + el_type = memref_type.get_element_type() + assert isinstance(el_type, IntegerType) + + return cls.from_rank_and_integer_type(memref_type.get_num_dims(), el_type) def verify(self) -> None: """ @@ -83,38 +84,35 @@ def verify(self) -> None: VerifyException: If the memref descriptor is invalid. """ - def raise_exception(message: str) -> None: - raise VerifyException("Invalid Memref Descriptor: " + message) - - if not isinstance(self.descriptor.types, ArrayAttr): - raise VerifyException("Expected result type to have an ArrayAttr") + def exception(message: str) -> VerifyException: + return VerifyException("Invalid Memref Descriptor: " + message) type_iter = iter(self.descriptor.types.data) if not isinstance(next(type_iter), LLVMPointerType): - raise_exception("Expected first element to be LLVMPointerType") + raise exception("Expected first element to be LLVMPointerType") if not isinstance(next(type_iter), LLVMPointerType): - raise_exception("Expected second element to be LLVMPointerType") + raise exception("Expected second element to be LLVMPointerType") if not isinstance(next(type_iter), IntegerType): - raise_exception("Expected third element to be IntegerType") + raise exception("Expected third element to be IntegerType") shape = next(type_iter) if not isinstance(shape, LLVMArrayType): - raise_exception("Expected fourth element to be LLVMArrayType") + raise exception("Expected fourth element to be LLVMArrayType") if not isinstance(shape.type, IntegerType): - raise_exception( + raise exception( "Expected fourth element to be LLVMArrayType of IntegerType" ) strides = next(type_iter) if not isinstance(strides, LLVMArrayType): - raise_exception("Expected fifth element to be LLVMArrayType") + raise exception("Expected fifth element to be LLVMArrayType") if not isinstance(strides.type, IntegerType): - raise_exception("Expected fifth element to be LLVMArrayType of IntegerType") + raise exception("Expected fifth element to be LLVMArrayType of IntegerType") if not strides.size.data == shape.size.data: - raise_exception("Expected shape and strides to have the same dimension") + raise exception("Expected shape and strides to have the same dimension") diff --git a/pixi.lock b/pixi.lock index 4cf22387..40a51941 100644 --- a/pixi.lock +++ b/pixi.lock @@ -2095,7 +2095,7 @@ packages: - pypi: . name: snax-mlir version: 0.2.2 - sha256: 55b922fb3119a196e110ee98278db73d1ab26ef7c3e03cbd497abac0c50d7bfa + sha256: aa69f080b757efe0802f2fd31707c3bd19ade9463d18e6dc1028f7c15643b847 requires_dist: - xdsl @ git+https://github.com/xdslproject/xdsl.git@d72f46d92ec4b03ae05b91e70d75f93735e94393 - pre-commit ; extra == 'dev' diff --git a/pyproject.toml b/pyproject.toml index 39b02239..1f04f9e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,9 +100,7 @@ typeCheckingMode = "strict" "compiler/transforms/snax_lower_mcycle.py", "compiler/transforms/snax_to_func.py", "compiler/transforms/test_remove_memref_copy.py", - "compiler/util/memref_descriptor.py", "tests/benchmark/test_snax_benchmark.py", - "tests/dialects/test_snax.py", "tests/inference/test_accfg_state_tracing.py", "tests/util/", ] diff --git a/tests/dialects/test_snax.py b/tests/dialects/test_snax.py index 4450d936..a867463f 100644 --- a/tests/dialects/test_snax.py +++ b/tests/dialects/test_snax.py @@ -1,9 +1,11 @@ import pytest from xdsl.dialects import builtin -from xdsl.dialects.builtin import ArrayAttr, StridedLayoutAttr, i32, i64 +from xdsl.dialects.builtin import ArrayAttr, MemRefType, StridedLayoutAttr, i32, i64 from xdsl.dialects.llvm import LLVMStructType -from xdsl.dialects.memref import MemRefType +from xdsl.ir import Attribute +from xdsl.parser import StringAttr from xdsl.utils.exceptions import VerifyException +from xdsl.utils.hints import isa from xdsl.utils.test_value import TestSSAValue from compiler.dialects.snax import Alloc, LayoutCast @@ -62,7 +64,7 @@ def test_memref_memory_space_cast(): memory_layout_cast = LayoutCast.from_type_and_target_layout(source_ssa, layout_2) assert memory_layout_cast.source is source_ssa - assert isinstance(memory_layout_cast.dest.type, MemRefType) + assert isa(memory_layout_cast.dest.type, MemRefType[Attribute]) assert memory_layout_cast.dest.type.layout is layout_2 @@ -73,6 +75,7 @@ def test_snax_alloc(): alloc_a = Alloc(dim, size, shape, memory_space=builtin.StringAttr("L1")) assert alloc_a.size is size + assert isinstance(alloc_a.memory_space, StringAttr) assert alloc_a.memory_space.data == "L1" assert isinstance(alloc_a.result.type, LLVMStructType)