diff --git a/compiler/ir/tsl/tiled_strided_layout.py b/compiler/ir/tsl/tiled_strided_layout.py index c14d4764..08c08d0a 100644 --- a/compiler/ir/tsl/tiled_strided_layout.py +++ b/compiler/ir/tsl/tiled_strided_layout.py @@ -4,6 +4,7 @@ from dataclasses import dataclass import numpy as np +from numpy._typing import NDArray from compiler.ir.tsl.stride import Stride from compiler.ir.tsl.tiled_stride import TiledStride @@ -76,14 +77,14 @@ def get_stride(self, dim: int, depth: int) -> Stride: the Tiled Strided Layout""" return self.tstrides[dim].strides[depth] - def all_values(self) -> np.ndarray: + def all_values(self) -> NDArray[np.int_]: """ Returns a numpy array containing all the elements in the iteration space. """ result = np.array([0]) for _, _, stride in self: - next_stride = np.array(stride.all_values()) + next_stride = np.array(stride.all_values(), dtype=np.int_) # for every stride, add a dimension and broadcast sum result = np.squeeze( np.expand_dims(result, -1) + np.expand_dims(next_stride, 0) diff --git a/compiler/parser/tsl_parser.py b/compiler/parser/tsl_parser.py index 623e7719..fb00fdad 100644 --- a/compiler/parser/tsl_parser.py +++ b/compiler/parser/tsl_parser.py @@ -19,23 +19,23 @@ def _parse_int_or_question(self, context_msg: str = "") -> int | None: return v self.raise_error("Expected an integer literal or `?`" + context_msg) - def _parse_step(self) -> list[int]: + def _parse_step(self) -> list[int | None]: """ steps ::== `(` steps (`,` steps)* `)` """ self._parse_token(Token.Kind.L_PAREN, "Expected opening bracket") - steps: list[int] = [] + steps: list[int | None] = [] while not self._parse_optional_token(Token.Kind.R_PAREN): steps.append(self._parse_int_or_question()) self._parse_optional_token(Token.Kind.COMMA) return steps - def _parse_bound(self) -> list[int]: + def _parse_bound(self) -> list[int | None]: """ bounds ::== `[` bound (`,` bound)* `]` """ self._parse_token(Token.Kind.L_SQUARE, "Expected opening bracket") - bounds: list[int] = [] + bounds: list[int | None] = [] while not self._parse_optional_token(Token.Kind.R_SQUARE): bounds.append(self._parse_int_or_question()) self._parse_optional_token(Token.Kind.COMMA) @@ -49,7 +49,9 @@ def _parse_tiled_stride(self) -> TiledStride: self._parse_token(Token.Kind.ARROW, "Expected arrow") steps = self._parse_step() if len(steps) != len(bounds): - raise ParseError("Expected same number of steps and bounds") + raise ParseError( + self._current_token.span, "Expected same number of steps and bounds" + ) # construct the tiledstrides return TiledStride([Stride(step, bound) for step, bound in zip(steps, bounds)]) @@ -57,7 +59,7 @@ def parse(self) -> TiledStridedLayout: """ tsl ::= tiled-stride (`,` tiled-stride)*` (, offset: ` offset)? """ - tstrides = [] + tstrides: list[TiledStride] = [] offset = 0 while True: if self._current_token.kind == Token.Kind.GREATER: diff --git a/pixi.lock b/pixi.lock index d828a424..bb7ce26d 100644 --- a/pixi.lock +++ b/pixi.lock @@ -2095,7 +2095,7 @@ packages: - pypi: . name: snax-mlir version: 0.2.2 - sha256: f5d2084d3c5c46a36d9e8ef168f5620c6b73f31f8c932574fb08737bfa74abe8 + sha256: f28a042de384b0631ba9d16ad396aafb8568bf11d709f4cd6053c9644121b76f requires_dist: - xdsl @ git+https://github.com/xdslproject/xdsl.git@d72f46d92ec4b03ae05b91e70d75f93735e94393 - pre-commit ; extra == 'dev' diff --git a/pyproject.toml b/pyproject.toml index 2a8aa733..87e0d0b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,8 +91,6 @@ typeCheckingMode = "strict" "compiler/inference/helpers.py", "compiler/inference/scoped_setups.py", "compiler/inference/trace_acc_state.py", - "compiler/ir/tsl/tiled_strided_layout.py", - "compiler/parser/tsl_parser.py", "compiler/transforms/accfg_dedup.py", "compiler/transforms/clear_memory_space.py", "compiler/transforms/convert_linalg_to_accfg.py", @@ -114,11 +112,7 @@ typeCheckingMode = "strict" "compiler/util/memref_descriptor.py", "tests/benchmark/test_snax_benchmark.py", "tests/dialects/test_snax.py", - "tests/dialects/test_tsl.py", "tests/inference/test_accfg_state_tracing.py", - "tests/ir/tsl/test_stride.py", - "tests/ir/tsl/test_tiled_stride.py", - "tests/ir/tsl/test_tiled_strided_layout.py", "tests/util/", ] diff --git a/tests/dialects/test_tsl.py b/tests/dialects/test_tsl.py index f36d0c43..d04be87a 100644 --- a/tests/dialects/test_tsl.py +++ b/tests/dialects/test_tsl.py @@ -27,13 +27,13 @@ def example_tsl_attr(): return tsl_attr -def test_tsl_attr_constructor(example_tsl_attr): +def test_tsl_attr_constructor(example_tsl_attr: TiledStridedLayoutAttr): tsl = example_tsl_attr assert isinstance(tsl, TiledStridedLayoutAttr) assert isinstance(tsl.data, TiledStridedLayout) -def test_tsl_attr_get_affine(example_tsl_attr): +def test_tsl_attr_get_affine(example_tsl_attr: TiledStridedLayoutAttr): tsl = example_tsl_attr map = canonicalize_map(tsl.get_affine_map()) assert map == canonicalize_map( diff --git a/tests/ir/tsl/test_stride.py b/tests/ir/tsl/test_stride.py index 0ff3d053..781c5e1c 100644 --- a/tests/ir/tsl/test_stride.py +++ b/tests/ir/tsl/test_stride.py @@ -8,7 +8,7 @@ def example_strides(): return (Stride(1, 4), Stride(4, 6), Stride(24, 2), Stride(None, None)) -def test_stride_constructor(example_strides): +def test_stride_constructor(example_strides: tuple[Stride, ...]): stride1, stride2, _, dynamic_stride = example_strides assert stride1.step == 1 assert stride1.bound == 4 @@ -18,7 +18,7 @@ def test_stride_constructor(example_strides): assert dynamic_stride.bound is None -def test_stride_all_values(example_strides): +def test_stride_all_values(example_strides: tuple[Stride, ...]): stride1, stride2, _, dynamic_stride = example_strides assert stride1.all_values() == [0, 1, 2, 3] assert stride2.all_values() == [0, 4, 8, 12, 16, 20] @@ -26,7 +26,7 @@ def test_stride_all_values(example_strides): dynamic_stride.all_values() -def test_stride_str(example_strides): +def test_stride_str(example_strides: tuple[Stride, ...]): stride1, stride2, stride3, dynamic_stride = example_strides assert str(stride1) == "4 -> 1" assert str(stride2) == "6 -> 4" diff --git a/tests/ir/tsl/test_tiled_stride.py b/tests/ir/tsl/test_tiled_stride.py index ddd49413..8f3bbfee 100644 --- a/tests/ir/tsl/test_tiled_stride.py +++ b/tests/ir/tsl/test_tiled_stride.py @@ -10,7 +10,7 @@ def example_strides(): @pytest.fixture() -def example_tiled_strides(example_strides): +def example_tiled_strides(example_strides: tuple[Stride, ...]): stride1, stride2, stride3, dynamic_stride = example_strides tiledStride1 = TiledStride([stride2, stride1]) tiledStride2 = TiledStride([stride3, stride2, stride1]) @@ -18,7 +18,9 @@ def example_tiled_strides(example_strides): return tiledStride1, tiledStride2, tiledStride3 -def test_tiled_stride_constructor(example_strides, example_tiled_strides): +def test_tiled_stride_constructor( + example_strides: tuple[Stride, ...], example_tiled_strides: tuple[TiledStride, ...] +): stride1, stride2, stride3, _ = example_strides tiledStride1, tiledStride2, _ = example_tiled_strides assert tiledStride1.strides[0] == stride2 @@ -38,21 +40,23 @@ def test_tiled_stride_from_stride(): assert tiledStride2.strides[2] == Stride(24, 4) -def test_tiled_stride_depth(example_tiled_strides): +def test_tiled_stride_depth(example_tiled_strides: tuple[TiledStride, ...]): tiledStride1, tiledStride2, tiledStride3 = example_tiled_strides assert tiledStride1.depth() == 2 assert tiledStride2.depth() == 3 assert tiledStride3.depth() == 2 -def test_tiled_stride_str(example_tiled_strides): +def test_tiled_stride_str(example_tiled_strides: tuple[TiledStride, ...]): tiledStride1, tiledStride2, tiledStride3 = example_tiled_strides assert str(tiledStride1) == "[6, 4] -> (4, 1)" assert str(tiledStride2) == "[2, 6, 4] -> (24, 4, 1)" assert str(tiledStride3) == "[?, 4] -> (?, 1)" -def test_tiled_stride_iter(example_strides, example_tiled_strides): +def test_tiled_stride_iter( + example_strides: tuple[Stride, ...], example_tiled_strides: tuple[TiledStride, ...] +): stride1, stride2, stride3, _ = example_strides strides = [stride3, stride2, stride1] @@ -64,7 +68,7 @@ def test_tiled_stride_iter(example_strides, example_tiled_strides): assert stride == strides[depth] -def test_tiled_stride_tile_bounds(example_tiled_strides): +def test_tiled_stride_tile_bounds(example_tiled_strides: tuple[TiledStride, ...]): tiledStride1, tiledStride2, tiledStride3 = example_tiled_strides assert tiledStride1.tile_bounds() == [6, 4] assert tiledStride2.tile_bounds() == [2, 6, 4] diff --git a/tests/ir/tsl/test_tiled_strided_layout.py b/tests/ir/tsl/test_tiled_strided_layout.py index 1f8f8389..dbf25d32 100644 --- a/tests/ir/tsl/test_tiled_strided_layout.py +++ b/tests/ir/tsl/test_tiled_strided_layout.py @@ -30,7 +30,7 @@ def example_tsl(): return tsl, tsl2 -def test_tsl_constructor(example_tsl): +def test_tsl_constructor(example_tsl: tuple[TiledStridedLayout, ...]): tsl, _ = example_tsl assert isinstance(tsl.tstrides[0], TiledStride) assert isinstance(tsl.tstrides[1], TiledStride) @@ -38,7 +38,7 @@ def test_tsl_constructor(example_tsl): def test_tsl_from_strides(): strides = [None, 1] - tile_bounds = [[16, 4], [16, 4]] + tile_bounds: list[list[int | None]] = [[16, 4], [16, 4]] tsl_constructor = TiledStridedLayout( [ TiledStride([Stride(None, 16), Stride(None, 4)]), @@ -49,13 +49,13 @@ def test_tsl_from_strides(): assert tsl_constructor == tsl_from_strides -def test_tsl_str(example_tsl): +def test_tsl_str(example_tsl: tuple[TiledStridedLayout, ...]): tsl, tsl2 = example_tsl assert str(tsl) == "[2, 4] -> (32, 4), [2, 4] -> (16, 1), offset: 5" assert str(tsl2) == "[2, 4] -> (32, 4), [?, 4] -> (?, 1), offset: 7" -def test_tsl_iter(example_tsl): +def test_tsl_iter(example_tsl: tuple[TiledStridedLayout, ...]): tsl, _ = example_tsl count = 0 for dim, depth, stride in tsl: @@ -67,19 +67,19 @@ def test_tsl_iter(example_tsl): assert count == tsl.dimension() * tsl.tstrides[0].depth() -def test_tsl_all_values(example_tsl): +def test_tsl_all_values(example_tsl: tuple[TiledStridedLayout, ...]): tsl, tsl2 = example_tsl assert set(tsl.all_values()) == set(range(64)) with pytest.raises(ValueError): tsl2.all_values() -def test_tsl_tile_bounds(example_tsl): +def test_tsl_tile_bounds(example_tsl: tuple[TiledStridedLayout, ...]): tsl, _ = example_tsl assert tsl.tile_bounds() == [[2, 4], [2, 4]] -def test_tsl_self_overlaps(example_tsl): +def test_tsl_self_overlaps(example_tsl: tuple[TiledStridedLayout, ...]): tsl, _ = example_tsl assert not tsl.self_overlaps() @@ -100,7 +100,7 @@ def test_tsl_self_overlaps(example_tsl): assert tsl2.self_overlaps() -def test_tsl_is_dense(example_tsl): +def test_tsl_is_dense(example_tsl: tuple[TiledStridedLayout, ...]): tsl, _ = example_tsl assert tsl.is_dense() @@ -121,7 +121,7 @@ def test_tsl_is_dense(example_tsl): assert not tsl2.is_dense() -def test_tsl_equal_tile_bounds(example_tsl): +def test_tsl_equal_tile_bounds(example_tsl: tuple[TiledStridedLayout, ...]): tsl, tsl2 = example_tsl assert tsl.equal_tile_bounds(tsl) assert not tsl.equal_tile_bounds(tsl2)