Skip to content

Commit

Permalink
AccessPattern utils: change to AffineTransform representations (#325)
Browse files Browse the repository at this point in the history
* change to affinetransform

first unit tests working

rebase fixes

* fix convert_stream_to_snax_stream

* format

* fix pyright
  • Loading branch information
jorendumoulin authored Jan 6, 2025
1 parent d663036 commit 8e90f61
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 132 deletions.
5 changes: 5 additions & 0 deletions compiler/ir/autoflow/affine_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,5 +127,10 @@ def compose(self, other: Self) -> Self:
new_b = self.A @ other.b + self.b
return type(self)(new_A, new_b)

def __eq__(self, other: object) -> bool:
if not isinstance(other, AffineTransform):
return False
return (self.A == other.A).all() and (self.b == other.b).all()

def __str__(self):
return f"AffineTransform(A=\n{self.A},\nb={self.b})"
87 changes: 37 additions & 50 deletions compiler/ir/stream/access_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from typing import Generic, cast

from typing_extensions import Self, TypeVar, overload
from xdsl.ir.affine import AffineConstantExpr, AffineDimExpr, AffineExpr, AffineMap
from xdsl.ir.affine import AffineDimExpr, AffineMap

from compiler.util.canonicalize_affine import canonicalize_map
from compiler.ir.autoflow.affine_transform import AffineTransform


@dataclass(frozen=True)
Expand All @@ -17,27 +17,26 @@ class AccessPattern(ABC):
"""

bounds: tuple[int | None, ...]
pattern: AffineMap
pattern: AffineTransform

def __init__(self, bounds: Sequence[int | None], pattern: AffineMap):
def __init__(
self, bounds: Sequence[int | None], pattern: AffineMap | AffineTransform
):
# Convert bounds to a tuple
bounds = tuple(bounds)

if isinstance(pattern, AffineMap):
pattern = AffineTransform.from_affine_map(pattern)

# Perform validations
if len(bounds) != pattern.num_dims:
raise ValueError(
"The number of bounds should be equal to the dimension of the pattern"
)

if pattern.num_symbols > 0:
raise ValueError("Symbols in the pattern are not supported")

# Canonicalize the pattern
new_pattern = canonicalize_map(pattern)

# Assign attributes using object.__setattr__ due to frozen=True
object.__setattr__(self, "bounds", bounds)
object.__setattr__(self, "pattern", new_pattern)
object.__setattr__(self, "pattern", pattern)

@property
def num_dims(self):
Expand All @@ -54,14 +53,9 @@ def disable_dims(self, dim: int) -> Self:
For `dim` = 2, will return:
(d2) -> d2
"""
new_pattern = self.pattern.replace_dims_and_symbols(
tuple(AffineConstantExpr(0) for _ in range(dim))
+ tuple(AffineDimExpr(i) for i in range(self.num_dims - dim)),
[],
self.num_dims - dim,
0,
return type(self)(
self.bounds[dim:], AffineTransform(self.pattern.A[:, dim:], self.pattern.b)
)
return type(self)(self.bounds[dim:], new_pattern)


@dataclass(frozen=True)
Expand All @@ -75,7 +69,7 @@ class SchedulePattern(AccessPattern):
# constrain bounds to only be int
bounds: tuple[int, ...]

def __init__(self, bounds: Sequence[int], pattern: AffineMap):
def __init__(self, bounds: Sequence[int], pattern: AffineMap | AffineTransform):
if any(bound <= 0 for bound in bounds):
raise ValueError(
"All bounds must be static, strictly positive integers for a schedule"
Expand All @@ -100,13 +94,10 @@ def rotate(self, dim: int) -> Self:
# (0, 1, 2, 3, ..., dim-1, dim, dim+1, ..., num_dims - 1)
# --> (1, 2, 3, ..., dim-1, 0, dim, dim+1, ..., num_dims - 1)

new_dims = tuple(AffineDimExpr(i) for i in range(self.num_dims))
new_dims = new_dims[dim - 1 : dim] + new_dims[: dim - 1] + new_dims[dim:]
new_bounds = self.bounds[1:dim] + self.bounds[:1] + self.bounds[dim:]
new_a = self.pattern.A[:, [*range(1, dim), 0, *range(dim, self.num_dims)]]
new_pattern = AffineTransform(new_a, self.pattern.b)

new_pattern = self.pattern.replace_dims_and_symbols(
new_dims, [], self.num_dims, 0
)
return type(self)(new_bounds, new_pattern)

def tile_dim(self, dim: int, template_bound: int) -> Self:
Expand All @@ -127,15 +118,17 @@ def tile_dim(self, dim: int, template_bound: int) -> Self:
[2, 4, 2, 2]
"""
transform_map = AffineMap(
num_dims=self.num_dims + 1,
num_symbols=0,
# (d0, d1, d2, ..., dim-1) -> (d0, d1, d2, ..., dim-1)
results=tuple(AffineDimExpr(i) for i in range(dim))
# (dim) -> (template_bound * dim + dim + 1)
+ (AffineDimExpr(dim) * template_bound + AffineDimExpr(dim + 1),)
# (dim + 1, dim + 2, ...) -> (dim + 2, dim + 3, dim + 3)
+ tuple(AffineDimExpr(i + 1) for i in range(dim + 1, self.num_dims)),
transform_map = AffineTransform.from_affine_map(
AffineMap(
num_dims=self.num_dims + 1,
num_symbols=0,
# (d0, d1, d2, ..., dim-1) -> (d0, d1, d2, ..., dim-1)
results=tuple(AffineDimExpr(i) for i in range(dim))
# (dim) -> (template_bound * dim + dim + 1)
+ (AffineDimExpr(dim) * template_bound + AffineDimExpr(dim + 1),)
# (dim + 1, dim + 2, ...) -> (dim + 2, dim + 3, dim + 3)
+ tuple(AffineDimExpr(i + 1) for i in range(dim + 1, self.num_dims)),
)
)
new_pattern = self.pattern.compose(transform_map)
bound_to_tile = self.bounds[dim]
Expand All @@ -155,10 +148,12 @@ def add_dim(self) -> Self:
(d0, d1, d2) -> d1 + d2
"""
new_pattern = self.pattern
transform_map = AffineMap(
num_dims=self.num_dims + 1,
num_symbols=0,
results=tuple(AffineDimExpr(i + 1) for i in range(self.num_dims)),
transform_map = AffineTransform.from_affine_map(
AffineMap(
num_dims=self.num_dims + 1,
num_symbols=0,
results=tuple(AffineDimExpr(i + 1) for i in range(self.num_dims)),
)
)
new_pattern = self.pattern.compose(transform_map)
new_bounds = (1,) + self.bounds
Expand All @@ -173,7 +168,9 @@ class TemplatePattern(AccessPattern):
Templates should not be transformed through either tiling/rotating/others.
"""

def __init__(self, bounds: Sequence[int | None], pattern: AffineMap):
def __init__(
self, bounds: Sequence[int | None], pattern: AffineMap | AffineTransform
):
super().__init__(bounds, pattern)

def matches(self, sp: SchedulePattern):
Expand Down Expand Up @@ -243,21 +240,11 @@ def clear_unused_dims(self, bounds: tuple[int] | None = None) -> Self:
pattern_bounds = self._patterns[0].bounds
else:
pattern_bounds = bounds
unused_dims = tuple(i for i, bound in enumerate(pattern_bounds) if bound == 1)
dim_substitutions: list[AffineExpr] = []
unused_counter = 0
for dim in range(self.num_dims):
if dim not in unused_dims:
dim_substitutions.append(AffineDimExpr(dim - unused_counter))
else:
dim_substitutions.append(AffineConstantExpr(0))
unused_counter += 1
used_dims = tuple(i for i, bound in enumerate(pattern_bounds) if bound != 1)
return type(self)(
type(self._patterns[0])(
tuple(bound for bound in pattern_bounds if bound != 1),
sp.pattern.replace_dims_and_symbols(
dim_substitutions, [], self.num_dims - unused_counter, 0
),
AffineTransform(sp.pattern.A[:, used_dims], sp.pattern.b),
)
for sp in self
)
Expand Down
2 changes: 1 addition & 1 deletion compiler/transforms/convert_stream_to_snax_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def generate_one_list(n: int, i: int):
data_mem_map: AffineMap = memref_type.get_affine_map_in_bytes()

# Mapping from access to data:
access_data_map: AffineMap = schedule[operand].pattern
access_data_map: AffineMap = schedule[operand].pattern.to_affine_map()

# Mapping from access to memory:
access_mem_map: AffineMap = data_mem_map.compose(access_data_map)
Expand Down
Loading

0 comments on commit 8e90f61

Please sign in to comment.