Skip to content

Commit

Permalink
pyright: fix tsl-related stuff (#327)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendumoulin authored Jan 6, 2025
1 parent dd5510b commit 480c18b
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 35 deletions.
5 changes: 3 additions & 2 deletions compiler/ir/tsl/tiled_strided_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass

import numpy as np
from numpy._typing import NDArray

from compiler.ir.tsl.stride import Stride
from compiler.ir.tsl.tiled_stride import TiledStride
Expand Down Expand Up @@ -76,14 +77,14 @@ def get_stride(self, dim: int, depth: int) -> Stride:
the Tiled Strided Layout"""
return self.tstrides[dim].strides[depth]

def all_values(self) -> np.ndarray:
def all_values(self) -> NDArray[np.int_]:
"""
Returns a numpy array containing all the elements in the iteration space.
"""
result = np.array([0])

for _, _, stride in self:
next_stride = np.array(stride.all_values())
next_stride = np.array(stride.all_values(), dtype=np.int_)
# for every stride, add a dimension and broadcast sum
result = np.squeeze(
np.expand_dims(result, -1) + np.expand_dims(next_stride, 0)
Expand Down
14 changes: 8 additions & 6 deletions compiler/parser/tsl_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,23 @@ def _parse_int_or_question(self, context_msg: str = "") -> int | None:
return v
self.raise_error("Expected an integer literal or `?`" + context_msg)

def _parse_step(self) -> list[int]:
def _parse_step(self) -> list[int | None]:
"""
steps ::== `(` steps (`,` steps)* `)`
"""
self._parse_token(Token.Kind.L_PAREN, "Expected opening bracket")
steps: list[int] = []
steps: list[int | None] = []
while not self._parse_optional_token(Token.Kind.R_PAREN):
steps.append(self._parse_int_or_question())
self._parse_optional_token(Token.Kind.COMMA)
return steps

def _parse_bound(self) -> list[int]:
def _parse_bound(self) -> list[int | None]:
"""
bounds ::== `[` bound (`,` bound)* `]`
"""
self._parse_token(Token.Kind.L_SQUARE, "Expected opening bracket")
bounds: list[int] = []
bounds: list[int | None] = []
while not self._parse_optional_token(Token.Kind.R_SQUARE):
bounds.append(self._parse_int_or_question())
self._parse_optional_token(Token.Kind.COMMA)
Expand All @@ -49,15 +49,17 @@ def _parse_tiled_stride(self) -> TiledStride:
self._parse_token(Token.Kind.ARROW, "Expected arrow")
steps = self._parse_step()
if len(steps) != len(bounds):
raise ParseError("Expected same number of steps and bounds")
raise ParseError(
self._current_token.span, "Expected same number of steps and bounds"
)
# construct the tiledstrides
return TiledStride([Stride(step, bound) for step, bound in zip(steps, bounds)])

def parse(self) -> TiledStridedLayout:
"""
tsl ::= tiled-stride (`,` tiled-stride)*` (, offset: ` offset)?
"""
tstrides = []
tstrides: list[TiledStride] = []
offset = 0
while True:
if self._current_token.kind == Token.Kind.GREATER:
Expand Down
2 changes: 1 addition & 1 deletion pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 0 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@ typeCheckingMode = "strict"
"compiler/inference/helpers.py",
"compiler/inference/scoped_setups.py",
"compiler/inference/trace_acc_state.py",
"compiler/ir/tsl/tiled_strided_layout.py",
"compiler/parser/tsl_parser.py",
"compiler/transforms/accfg_dedup.py",
"compiler/transforms/clear_memory_space.py",
"compiler/transforms/convert_linalg_to_accfg.py",
Expand All @@ -114,11 +112,7 @@ typeCheckingMode = "strict"
"compiler/util/memref_descriptor.py",
"tests/benchmark/test_snax_benchmark.py",
"tests/dialects/test_snax.py",
"tests/dialects/test_tsl.py",
"tests/inference/test_accfg_state_tracing.py",
"tests/ir/tsl/test_stride.py",
"tests/ir/tsl/test_tiled_stride.py",
"tests/ir/tsl/test_tiled_strided_layout.py",
"tests/util/",
]

Expand Down
4 changes: 2 additions & 2 deletions tests/dialects/test_tsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ def example_tsl_attr():
return tsl_attr


def test_tsl_attr_constructor(example_tsl_attr):
def test_tsl_attr_constructor(example_tsl_attr: TiledStridedLayoutAttr):
tsl = example_tsl_attr
assert isinstance(tsl, TiledStridedLayoutAttr)
assert isinstance(tsl.data, TiledStridedLayout)


def test_tsl_attr_get_affine(example_tsl_attr):
def test_tsl_attr_get_affine(example_tsl_attr: TiledStridedLayoutAttr):
tsl = example_tsl_attr
map = canonicalize_map(tsl.get_affine_map())
assert map == canonicalize_map(
Expand Down
6 changes: 3 additions & 3 deletions tests/ir/tsl/test_stride.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def example_strides():
return (Stride(1, 4), Stride(4, 6), Stride(24, 2), Stride(None, None))


def test_stride_constructor(example_strides):
def test_stride_constructor(example_strides: tuple[Stride, ...]):
stride1, stride2, _, dynamic_stride = example_strides
assert stride1.step == 1
assert stride1.bound == 4
Expand All @@ -18,15 +18,15 @@ def test_stride_constructor(example_strides):
assert dynamic_stride.bound is None


def test_stride_all_values(example_strides):
def test_stride_all_values(example_strides: tuple[Stride, ...]):
stride1, stride2, _, dynamic_stride = example_strides
assert stride1.all_values() == [0, 1, 2, 3]
assert stride2.all_values() == [0, 4, 8, 12, 16, 20]
with pytest.raises(ValueError):
dynamic_stride.all_values()


def test_stride_str(example_strides):
def test_stride_str(example_strides: tuple[Stride, ...]):
stride1, stride2, stride3, dynamic_stride = example_strides
assert str(stride1) == "4 -> 1"
assert str(stride2) == "6 -> 4"
Expand Down
16 changes: 10 additions & 6 deletions tests/ir/tsl/test_tiled_stride.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@ def example_strides():


@pytest.fixture()
def example_tiled_strides(example_strides):
def example_tiled_strides(example_strides: tuple[Stride, ...]):
stride1, stride2, stride3, dynamic_stride = example_strides
tiledStride1 = TiledStride([stride2, stride1])
tiledStride2 = TiledStride([stride3, stride2, stride1])
tiledStride3 = TiledStride([dynamic_stride, stride1])
return tiledStride1, tiledStride2, tiledStride3


def test_tiled_stride_constructor(example_strides, example_tiled_strides):
def test_tiled_stride_constructor(
example_strides: tuple[Stride, ...], example_tiled_strides: tuple[TiledStride, ...]
):
stride1, stride2, stride3, _ = example_strides
tiledStride1, tiledStride2, _ = example_tiled_strides
assert tiledStride1.strides[0] == stride2
Expand All @@ -38,21 +40,23 @@ def test_tiled_stride_from_stride():
assert tiledStride2.strides[2] == Stride(24, 4)


def test_tiled_stride_depth(example_tiled_strides):
def test_tiled_stride_depth(example_tiled_strides: tuple[TiledStride, ...]):
tiledStride1, tiledStride2, tiledStride3 = example_tiled_strides
assert tiledStride1.depth() == 2
assert tiledStride2.depth() == 3
assert tiledStride3.depth() == 2


def test_tiled_stride_str(example_tiled_strides):
def test_tiled_stride_str(example_tiled_strides: tuple[TiledStride, ...]):
tiledStride1, tiledStride2, tiledStride3 = example_tiled_strides
assert str(tiledStride1) == "[6, 4] -> (4, 1)"
assert str(tiledStride2) == "[2, 6, 4] -> (24, 4, 1)"
assert str(tiledStride3) == "[?, 4] -> (?, 1)"


def test_tiled_stride_iter(example_strides, example_tiled_strides):
def test_tiled_stride_iter(
example_strides: tuple[Stride, ...], example_tiled_strides: tuple[TiledStride, ...]
):
stride1, stride2, stride3, _ = example_strides
strides = [stride3, stride2, stride1]

Expand All @@ -64,7 +68,7 @@ def test_tiled_stride_iter(example_strides, example_tiled_strides):
assert stride == strides[depth]


def test_tiled_stride_tile_bounds(example_tiled_strides):
def test_tiled_stride_tile_bounds(example_tiled_strides: tuple[TiledStride, ...]):
tiledStride1, tiledStride2, tiledStride3 = example_tiled_strides
assert tiledStride1.tile_bounds() == [6, 4]
assert tiledStride2.tile_bounds() == [2, 6, 4]
Expand Down
18 changes: 9 additions & 9 deletions tests/ir/tsl/test_tiled_strided_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ def example_tsl():
return tsl, tsl2


def test_tsl_constructor(example_tsl):
def test_tsl_constructor(example_tsl: tuple[TiledStridedLayout, ...]):
tsl, _ = example_tsl
assert isinstance(tsl.tstrides[0], TiledStride)
assert isinstance(tsl.tstrides[1], TiledStride)


def test_tsl_from_strides():
strides = [None, 1]
tile_bounds = [[16, 4], [16, 4]]
tile_bounds: list[list[int | None]] = [[16, 4], [16, 4]]
tsl_constructor = TiledStridedLayout(
[
TiledStride([Stride(None, 16), Stride(None, 4)]),
Expand All @@ -49,13 +49,13 @@ def test_tsl_from_strides():
assert tsl_constructor == tsl_from_strides


def test_tsl_str(example_tsl):
def test_tsl_str(example_tsl: tuple[TiledStridedLayout, ...]):
tsl, tsl2 = example_tsl
assert str(tsl) == "[2, 4] -> (32, 4), [2, 4] -> (16, 1), offset: 5"
assert str(tsl2) == "[2, 4] -> (32, 4), [?, 4] -> (?, 1), offset: 7"


def test_tsl_iter(example_tsl):
def test_tsl_iter(example_tsl: tuple[TiledStridedLayout, ...]):
tsl, _ = example_tsl
count = 0
for dim, depth, stride in tsl:
Expand All @@ -67,19 +67,19 @@ def test_tsl_iter(example_tsl):
assert count == tsl.dimension() * tsl.tstrides[0].depth()


def test_tsl_all_values(example_tsl):
def test_tsl_all_values(example_tsl: tuple[TiledStridedLayout, ...]):
tsl, tsl2 = example_tsl
assert set(tsl.all_values()) == set(range(64))
with pytest.raises(ValueError):
tsl2.all_values()


def test_tsl_tile_bounds(example_tsl):
def test_tsl_tile_bounds(example_tsl: tuple[TiledStridedLayout, ...]):
tsl, _ = example_tsl
assert tsl.tile_bounds() == [[2, 4], [2, 4]]


def test_tsl_self_overlaps(example_tsl):
def test_tsl_self_overlaps(example_tsl: tuple[TiledStridedLayout, ...]):
tsl, _ = example_tsl
assert not tsl.self_overlaps()

Expand All @@ -100,7 +100,7 @@ def test_tsl_self_overlaps(example_tsl):
assert tsl2.self_overlaps()


def test_tsl_is_dense(example_tsl):
def test_tsl_is_dense(example_tsl: tuple[TiledStridedLayout, ...]):
tsl, _ = example_tsl
assert tsl.is_dense()

Expand All @@ -121,7 +121,7 @@ def test_tsl_is_dense(example_tsl):
assert not tsl2.is_dense()


def test_tsl_equal_tile_bounds(example_tsl):
def test_tsl_equal_tile_bounds(example_tsl: tuple[TiledStridedLayout, ...]):
tsl, tsl2 = example_tsl
assert tsl.equal_tile_bounds(tsl)
assert not tsl.equal_tile_bounds(tsl2)
Expand Down

0 comments on commit 480c18b

Please sign in to comment.