diff --git a/compiler/ir/autoflow/affine_transform.py b/compiler/ir/autoflow/affine_transform.py index cab29721..e3dce0d9 100644 --- a/compiler/ir/autoflow/affine_transform.py +++ b/compiler/ir/autoflow/affine_transform.py @@ -127,5 +127,10 @@ def compose(self, other: Self) -> Self: new_b = self.A @ other.b + self.b return type(self)(new_A, new_b) + def __eq__(self, other: object) -> bool: + if not isinstance(other, AffineTransform): + return False + return (self.A == other.A).all() and (self.b == other.b).all() + def __str__(self): return f"AffineTransform(A=\n{self.A},\nb={self.b})" diff --git a/compiler/ir/stream/access_pattern.py b/compiler/ir/stream/access_pattern.py index c067369f..a28d87f6 100644 --- a/compiler/ir/stream/access_pattern.py +++ b/compiler/ir/stream/access_pattern.py @@ -4,9 +4,9 @@ from typing import Generic, cast from typing_extensions import Self, TypeVar, overload -from xdsl.ir.affine import AffineConstantExpr, AffineDimExpr, AffineExpr, AffineMap +from xdsl.ir.affine import AffineDimExpr, AffineMap -from compiler.util.canonicalize_affine import canonicalize_map +from compiler.ir.autoflow.affine_transform import AffineTransform @dataclass(frozen=True) @@ -17,27 +17,26 @@ class AccessPattern(ABC): """ bounds: tuple[int | None, ...] - pattern: AffineMap + pattern: AffineTransform - def __init__(self, bounds: Sequence[int | None], pattern: AffineMap): + def __init__( + self, bounds: Sequence[int | None], pattern: AffineMap | AffineTransform + ): # Convert bounds to a tuple bounds = tuple(bounds) + if isinstance(pattern, AffineMap): + pattern = AffineTransform.from_affine_map(pattern) + # Perform validations if len(bounds) != pattern.num_dims: raise ValueError( "The number of bounds should be equal to the dimension of the pattern" ) - if pattern.num_symbols > 0: - raise ValueError("Symbols in the pattern are not supported") - - # Canonicalize the pattern - new_pattern = canonicalize_map(pattern) - # Assign attributes using object.__setattr__ due to frozen=True object.__setattr__(self, "bounds", bounds) - object.__setattr__(self, "pattern", new_pattern) + object.__setattr__(self, "pattern", pattern) @property def num_dims(self): @@ -54,14 +53,9 @@ def disable_dims(self, dim: int) -> Self: For `dim` = 2, will return: (d2) -> d2 """ - new_pattern = self.pattern.replace_dims_and_symbols( - tuple(AffineConstantExpr(0) for _ in range(dim)) - + tuple(AffineDimExpr(i) for i in range(self.num_dims - dim)), - [], - self.num_dims - dim, - 0, + return type(self)( + self.bounds[dim:], AffineTransform(self.pattern.A[:, dim:], self.pattern.b) ) - return type(self)(self.bounds[dim:], new_pattern) @dataclass(frozen=True) @@ -75,7 +69,7 @@ class SchedulePattern(AccessPattern): # constrain bounds to only be int bounds: tuple[int, ...] - def __init__(self, bounds: Sequence[int], pattern: AffineMap): + def __init__(self, bounds: Sequence[int], pattern: AffineMap | AffineTransform): if any(bound <= 0 for bound in bounds): raise ValueError( "All bounds must be static, strictly positive integers for a schedule" @@ -100,13 +94,10 @@ def rotate(self, dim: int) -> Self: # (0, 1, 2, 3, ..., dim-1, dim, dim+1, ..., num_dims - 1) # --> (1, 2, 3, ..., dim-1, 0, dim, dim+1, ..., num_dims - 1) - new_dims = tuple(AffineDimExpr(i) for i in range(self.num_dims)) - new_dims = new_dims[dim - 1 : dim] + new_dims[: dim - 1] + new_dims[dim:] new_bounds = self.bounds[1:dim] + self.bounds[:1] + self.bounds[dim:] + new_a = self.pattern.A[:, [*range(1, dim), 0, *range(dim, self.num_dims)]] + new_pattern = AffineTransform(new_a, self.pattern.b) - new_pattern = self.pattern.replace_dims_and_symbols( - new_dims, [], self.num_dims, 0 - ) return type(self)(new_bounds, new_pattern) def tile_dim(self, dim: int, template_bound: int) -> Self: @@ -127,15 +118,17 @@ def tile_dim(self, dim: int, template_bound: int) -> Self: [2, 4, 2, 2] """ - transform_map = AffineMap( - num_dims=self.num_dims + 1, - num_symbols=0, - # (d0, d1, d2, ..., dim-1) -> (d0, d1, d2, ..., dim-1) - results=tuple(AffineDimExpr(i) for i in range(dim)) - # (dim) -> (template_bound * dim + dim + 1) - + (AffineDimExpr(dim) * template_bound + AffineDimExpr(dim + 1),) - # (dim + 1, dim + 2, ...) -> (dim + 2, dim + 3, dim + 3) - + tuple(AffineDimExpr(i + 1) for i in range(dim + 1, self.num_dims)), + transform_map = AffineTransform.from_affine_map( + AffineMap( + num_dims=self.num_dims + 1, + num_symbols=0, + # (d0, d1, d2, ..., dim-1) -> (d0, d1, d2, ..., dim-1) + results=tuple(AffineDimExpr(i) for i in range(dim)) + # (dim) -> (template_bound * dim + dim + 1) + + (AffineDimExpr(dim) * template_bound + AffineDimExpr(dim + 1),) + # (dim + 1, dim + 2, ...) -> (dim + 2, dim + 3, dim + 3) + + tuple(AffineDimExpr(i + 1) for i in range(dim + 1, self.num_dims)), + ) ) new_pattern = self.pattern.compose(transform_map) bound_to_tile = self.bounds[dim] @@ -155,10 +148,12 @@ def add_dim(self) -> Self: (d0, d1, d2) -> d1 + d2 """ new_pattern = self.pattern - transform_map = AffineMap( - num_dims=self.num_dims + 1, - num_symbols=0, - results=tuple(AffineDimExpr(i + 1) for i in range(self.num_dims)), + transform_map = AffineTransform.from_affine_map( + AffineMap( + num_dims=self.num_dims + 1, + num_symbols=0, + results=tuple(AffineDimExpr(i + 1) for i in range(self.num_dims)), + ) ) new_pattern = self.pattern.compose(transform_map) new_bounds = (1,) + self.bounds @@ -173,7 +168,9 @@ class TemplatePattern(AccessPattern): Templates should not be transformed through either tiling/rotating/others. """ - def __init__(self, bounds: Sequence[int | None], pattern: AffineMap): + def __init__( + self, bounds: Sequence[int | None], pattern: AffineMap | AffineTransform + ): super().__init__(bounds, pattern) def matches(self, sp: SchedulePattern): @@ -243,21 +240,11 @@ def clear_unused_dims(self, bounds: tuple[int] | None = None) -> Self: pattern_bounds = self._patterns[0].bounds else: pattern_bounds = bounds - unused_dims = tuple(i for i, bound in enumerate(pattern_bounds) if bound == 1) - dim_substitutions: list[AffineExpr] = [] - unused_counter = 0 - for dim in range(self.num_dims): - if dim not in unused_dims: - dim_substitutions.append(AffineDimExpr(dim - unused_counter)) - else: - dim_substitutions.append(AffineConstantExpr(0)) - unused_counter += 1 + used_dims = tuple(i for i, bound in enumerate(pattern_bounds) if bound != 1) return type(self)( type(self._patterns[0])( tuple(bound for bound in pattern_bounds if bound != 1), - sp.pattern.replace_dims_and_symbols( - dim_substitutions, [], self.num_dims - unused_counter, 0 - ), + AffineTransform(sp.pattern.A[:, used_dims], sp.pattern.b), ) for sp in self ) diff --git a/compiler/transforms/convert_stream_to_snax_stream.py b/compiler/transforms/convert_stream_to_snax_stream.py index e5dbb8db..4d7913ba 100644 --- a/compiler/transforms/convert_stream_to_snax_stream.py +++ b/compiler/transforms/convert_stream_to_snax_stream.py @@ -96,7 +96,7 @@ def generate_one_list(n: int, i: int): data_mem_map: AffineMap = memref_type.get_affine_map_in_bytes() # Mapping from access to data: - access_data_map: AffineMap = schedule[operand].pattern + access_data_map: AffineMap = schedule[operand].pattern.to_affine_map() # Mapping from access to memory: access_mem_map: AffineMap = data_mem_map.compose(access_data_map) diff --git a/tests/ir/stream/test_access_pattern.py b/tests/ir/stream/test_access_pattern.py index eb3dc7a4..dd9ce533 100644 --- a/tests/ir/stream/test_access_pattern.py +++ b/tests/ir/stream/test_access_pattern.py @@ -1,6 +1,8 @@ +import numpy as np import pytest -from xdsl.ir.affine import AffineConstantExpr, AffineDimExpr, AffineMap +from xdsl.ir.affine import AffineDimExpr, AffineMap +from compiler.ir.autoflow import AffineTransform from compiler.ir.stream import ( AccessPattern, Schedule, @@ -20,60 +22,53 @@ def test_access_pattern_creation(): bounds = (10, 20, 30) access_pattern = AccessPattern(bounds, pattern) assert access_pattern.bounds == bounds - assert access_pattern.pattern == pattern + assert access_pattern.pattern == AffineTransform.from_affine_map(pattern) assert access_pattern.num_dims == 3 def test_access_pattern_disable_dims(): - pattern = AffineMap( - num_dims=3, - num_symbols=0, - results=(AffineDimExpr(0), AffineDimExpr(1), AffineDimExpr(2)), + pattern = AffineTransform( + np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), b=np.array([1, 2, 3]) ) bounds = (10, 20, 30) access_pattern = AccessPattern(bounds, pattern) # test 1: disable 0 dims (none) disabled_pattern = access_pattern.disable_dims(0) - expected_bounds = (10, 20, 30) - expected_results = (AffineDimExpr(0), AffineDimExpr(1), AffineDimExpr(2)) - assert disabled_pattern.bounds == expected_bounds - assert disabled_pattern.pattern.results == expected_results + assert disabled_pattern.bounds == bounds + assert disabled_pattern.pattern == pattern assert isinstance(disabled_pattern, AccessPattern) # test 2: disable 1 dims disabled_pattern = access_pattern.disable_dims(1) expected_bounds = (20, 30) - expected_results = (AffineConstantExpr(0), AffineDimExpr(0), AffineDimExpr(1)) + expected_results = np.array([[0, 0], [1, 0], [0, 1]]) assert disabled_pattern.bounds == expected_bounds - assert disabled_pattern.pattern.results == expected_results + assert (disabled_pattern.pattern.A == expected_results).all() + assert (disabled_pattern.pattern.b == pattern.b).all() assert isinstance(disabled_pattern, AccessPattern) # test 3: disable 2 dims disabled_pattern = access_pattern.disable_dims(2) expected_bounds = (30,) - expected_results = (AffineConstantExpr(0), AffineConstantExpr(0), AffineDimExpr(0)) + expected_results = np.array([[0], [0], [1]]) assert disabled_pattern.bounds == expected_bounds - assert disabled_pattern.pattern.results == expected_results + assert (disabled_pattern.pattern.A == expected_results).all() + assert (disabled_pattern.pattern.b == pattern.b).all() assert isinstance(disabled_pattern, AccessPattern) # test 4: disable 3 dims (all) disabled_pattern = access_pattern.disable_dims(3) expected_bounds: tuple[int, ...] = tuple() - expected_results = ( - AffineConstantExpr(0), - AffineConstantExpr(0), - AffineConstantExpr(0), - ) + expected_results = [] assert disabled_pattern.bounds == expected_bounds - assert disabled_pattern.pattern.results == expected_results + assert (disabled_pattern.pattern.A == expected_results).all() + assert (disabled_pattern.pattern.b == pattern.b).all() assert isinstance(disabled_pattern, AccessPattern) def test_schedule_pattern_creation(): - pattern = AffineMap( - num_dims=2, num_symbols=0, results=(AffineDimExpr(0), AffineDimExpr(1)) - ) + pattern = AffineTransform(np.array([[1, 0], [0, 1]]), np.array([0, 0])) bounds = (15, 25) schedule_pattern = SchedulePattern(bounds, pattern) assert schedule_pattern.bounds == bounds @@ -84,9 +79,7 @@ def test_schedule_pattern_creation(): def test_schedule_pattern_invalid_bounds(): - pattern = AffineMap( - num_dims=2, num_symbols=0, results=(AffineDimExpr(0), AffineDimExpr(1)) - ) + pattern = AffineTransform(np.array([[1, 0], [0, 1]]), np.array([0, 0])) with pytest.raises( ValueError, match="All bounds must be static, strictly positive integers for a schedule", @@ -95,10 +88,8 @@ def test_schedule_pattern_invalid_bounds(): def test_schedule_pattern_rotate(): - pattern = AffineMap( - num_dims=3, - num_symbols=0, - results=(AffineDimExpr(0), AffineDimExpr(1), AffineDimExpr(2)), + pattern = AffineTransform( + np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), np.array([0, 0, 0]) ) bounds = (10, 20, 30) access_pattern = SchedulePattern(bounds, pattern) @@ -106,37 +97,33 @@ def test_schedule_pattern_rotate(): # test 1: 3 dims, rotate 2 rotated_pattern = access_pattern.rotate(2) expected_bounds = (20, 10, 30) - expected_results = (AffineDimExpr(1), AffineDimExpr(0), AffineDimExpr(2)) + expected_results = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) assert rotated_pattern.bounds == expected_bounds - assert rotated_pattern.pattern.results == expected_results + assert (rotated_pattern.pattern.A == expected_results).all() + assert (rotated_pattern.pattern.b == pattern.b).all() assert isinstance(rotated_pattern, AccessPattern) # test 2: 3 dims, rotate 3 rotated_pattern = access_pattern.rotate(3) expected_bounds = (20, 30, 10) - expected_results = (AffineDimExpr(2), AffineDimExpr(0), AffineDimExpr(1)) + expected_results = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) assert rotated_pattern.bounds == expected_bounds - assert rotated_pattern.pattern.results == expected_results + assert (rotated_pattern.pattern.A == expected_results).all() + assert (rotated_pattern.pattern.b == pattern.b).all() assert isinstance(rotated_pattern, AccessPattern) # test 3: 3 dims, rotate 1 rotated_pattern = access_pattern.rotate(1) expected_bounds = (10, 20, 30) - expected_results = (AffineDimExpr(0), AffineDimExpr(1), AffineDimExpr(2)) assert rotated_pattern.bounds == expected_bounds - assert rotated_pattern.pattern.results == expected_results + assert (rotated_pattern.pattern.A == pattern.A).all() + assert (rotated_pattern.pattern.b == pattern.b).all() assert isinstance(rotated_pattern, AccessPattern) # test 4 dims - pattern = AffineMap( - num_dims=4, - num_symbols=0, - results=( - AffineDimExpr(0), - AffineDimExpr(1), - AffineDimExpr(2), - AffineDimExpr(3), - ), + pattern = AffineTransform( + np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]), + np.array([0, 0, 0, 0]), ) bounds = (10, 20, 30, 40) access_pattern = SchedulePattern(bounds, pattern) @@ -144,61 +131,45 @@ def test_schedule_pattern_rotate(): # test 4: 4 dims, rotate 3 rotated_pattern = access_pattern.rotate(3) expected_bounds = (20, 30, 10, 40) - expected_results = ( - AffineDimExpr(2), - AffineDimExpr(0), - AffineDimExpr(1), - AffineDimExpr(3), + expected_results = np.array( + [[0, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]] ) assert rotated_pattern.bounds == expected_bounds - assert rotated_pattern.pattern.results == expected_results + assert (rotated_pattern.pattern.A == expected_results).all() + assert (rotated_pattern.pattern.b == pattern.b).all() assert isinstance(rotated_pattern, AccessPattern) def test_schedule_pattern_add_dim(): - pattern = AffineMap( - num_dims=3, - num_symbols=0, - results=(AffineDimExpr(0), AffineDimExpr(1), AffineDimExpr(2)), + pattern = AffineTransform( + np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), np.array([0, 0, 0]) ) bounds = (10, 20, 30) access_pattern = SchedulePattern(bounds, pattern) pattern_new_dim = access_pattern.add_dim() expected_bounds = (1, 10, 20, 30) - expected_results = ( - AffineDimExpr(1), - AffineDimExpr(2), - AffineDimExpr(3), - ) + expected_results = np.array([[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) assert pattern_new_dim.bounds == expected_bounds - assert pattern_new_dim.pattern.results == expected_results + assert (pattern_new_dim.pattern.A == expected_results).all() assert isinstance(pattern_new_dim, SchedulePattern) def test_schedule_pattern_tile_dim(): - pattern = AffineMap( - num_dims=3, - num_symbols=0, - results=(AffineDimExpr(0), AffineDimExpr(1), AffineDimExpr(2)), + pattern = AffineTransform( + np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), np.array([0, 0, 0]) ) bounds = (10, 20, 30) access_pattern = SchedulePattern(bounds, pattern) tiled_pattern = access_pattern.tile_dim(1, 5) expected_bounds = (10, 4, 5, 30) - expected_results = ( - AffineDimExpr(0), - AffineDimExpr(1) * 5 + AffineDimExpr(2), - AffineDimExpr(3), - ) + expected_results = np.array([[1, 0, 0, 0], [0, 5, 1, 0], [0, 0, 0, 1]]) assert tiled_pattern.bounds == expected_bounds - assert tiled_pattern.pattern.results == expected_results + assert (tiled_pattern.pattern.A == expected_results).all() assert isinstance(tiled_pattern, SchedulePattern) def test_template_pattern_creation(): - pattern = AffineMap( - num_dims=2, num_symbols=0, results=(AffineDimExpr(0), AffineDimExpr(1)) - ) + pattern = AffineTransform(np.array([[1, 0], [0, 1], [0, 1]]), np.array([0, 0, 0])) bounds = (5, 10) template_pattern = TemplatePattern(bounds, pattern) assert template_pattern.bounds == bounds @@ -270,19 +241,15 @@ def test_schedule_tile_dim(): def test_schedule_clear_unused_dims(): - pattern1 = AffineMap( - num_dims=3, num_symbols=0, results=(AffineDimExpr(0), AffineDimExpr(1)) - ) + pattern1 = AffineTransform(np.array([[1, 0, 0], [0, 1, 0]]), np.array([0, 0])) sp1 = SchedulePattern((1, 10, 1), pattern1) schedule = Schedule([sp1]) cleared_schedule = schedule.clear_unused_dims() assert isinstance(cleared_schedule, Schedule) assert cleared_schedule[0].bounds == ((10,)) - assert cleared_schedule[0].pattern.results == ( - AffineConstantExpr(0), - AffineDimExpr(0), - ) + expected_results = np.array([[0], [1]]) + assert (cleared_schedule[0].pattern.A == expected_results).all() def test_template_disable_dims():