Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the raising to high level operator within Unify Axis #565

Open
wants to merge 32 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
afc3e9e
Move the AxesEquationCollector to not use the raising.py operations.
nkoskelo Nov 4, 2024
574ab23
Remove the usage check.
nkoskelo Nov 5, 2024
e02a800
Look at the variables directly and as long as we are using the reserv…
nkoskelo Nov 5, 2024
0948843
Correct ruff suggestions.
nkoskelo Nov 5, 2024
953a643
Only record usage if the array is indexed in some way.
nkoskelo Nov 6, 2024
88eac82
Add a unit test case which is unbroadcastable but is still a legal py…
nkoskelo Nov 6, 2024
52dc445
Add a unit test and split out a reserved pattern for the reductions a…
nkoskelo Nov 7, 2024
1031868
Fix ruff suggestions.
nkoskelo Nov 7, 2024
85e5395
More ruff suggestions.
nkoskelo Nov 7, 2024
a50d58e
Make sure that we return a value if we need to. :)
nkoskelo Nov 7, 2024
921b55f
Working on mypy errors.
nkoskelo Nov 25, 2024
e858110
Respond to comments.
nkoskelo Dec 11, 2024
a04374d
Merge branch 'main' into remove-raising-revived
nkoskelo Dec 11, 2024
17df871
Update for ruff.
nkoskelo Dec 11, 2024
5b01c24
Move typing information to only import if type checking.
nkoskelo Dec 11, 2024
882312f
More ruff CI.
nkoskelo Dec 12, 2024
0feea14
Add noqa: RUF052 for kernels in test_codegen.py.
nkoskelo Dec 12, 2024
8848fd5
Add assert statements for typing purposes.
nkoskelo Dec 12, 2024
a30b3bc
Reorganize. Ruff was out of date. :)
nkoskelo Dec 12, 2024
07bc5ab
Fix some of the mypy errors.
nkoskelo Dec 12, 2024
6e73ed6
Add a mapper for applying the updates in the case of a reduction oper…
nkoskelo Dec 20, 2024
408cdd5
Merge branch 'main' into remove-raising-revived
nkoskelo Dec 20, 2024
0336c00
Fix the ruff comments.
nkoskelo Dec 20, 2024
5242f3f
Update the test code to use the correct name of the argument of its i…
nkoskelo Dec 20, 2024
b14d238
Merge branch 'main' into remove-raising-revived
nkoskelo Jan 8, 2025
e7e750a
Add a test case for the pattern match on binding names. Use the reduc…
nkoskelo Jan 8, 2025
a7827b7
Merge branch 'main' into remove-raising-revived
nkoskelo Jan 9, 2025
23ffe34
Let's get the reduction descriptors at the start when we are recordin…
nkoskelo Jan 14, 2025
9ac0a7d
Update the record equations keys to be [Array, int | str] so that we …
nkoskelo Jan 14, 2025
5ccc11f
Remove unused code.
nkoskelo Jan 14, 2025
b522193
Save off changes moving down to just handling IndexLambdas and those …
nkoskelo Jan 24, 2025
d967d5d
Preconvert to using IndexLambdas. So need to get all of the cases han…
nkoskelo Jan 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pytato/scalar_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ def map_reduce(self, expr: Reduce) -> ScalarExpression:
for name, bound in expr.bounds.items()}))


IDX_LAMBDA_RE = re.compile(r"_r?(0|([1-9][0-9]*))")
IDX_LAMBDA_RE = re.compile(r"^(_r?(0|([1-9][0-9]*)))$")
IDX_LAMBDA_INAME = re.compile(r"^(_(0|([1-9][0-9]*)))$")
IDX_LAMBDA_JUST_REDUCTIONS = re.compile(r"^(_r(0|([1-9][0-9]*)))$")


class DependencyMapper(DependencyMapperBase[P]):
Expand Down
232 changes: 147 additions & 85 deletions pytato/transform/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,28 @@
THE SOFTWARE.
"""


import logging
import re
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
ParamSpec,
TypeAlias,
TypeVar,
cast,
)

from bidict import bidict

import pymbolic.primitives as prim
from pytools import UniqueNameGenerator
from pytools.tag import Tag

from pytato.array import (
AbstractResultWithNamedArrays,
AdvancedIndexInContiguousAxes,
Array,
ArrayOrScalar,
AxisPermutation,
BasicIndex,
Concatenate,
Expand All @@ -70,18 +73,14 @@
Reshape,
Stack,
)
from pytato.diagnostic import UnknownIndexLambdaExpr
from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
from pytato.raising import (
BinaryOp,
BroadcastOp,
C99CallOp,
FullOp,
ReduceOp,
WhereOp,
index_lambda_to_high_level_op,
from pytato.function import NamedCallResult
from pytato.scalar_expr import (
IDX_LAMBDA_INAME,
IDX_LAMBDA_JUST_REDUCTIONS,
CombineMapper,
IdentityMapper as ScalarIdentityMapper,
)
from pytato.scalar_expr import SCALAR_CLASSES
from pytato.transform import ArrayOrNames, CopyMapper, Mapper
from pytato.utils import are_shape_components_equal, are_shapes_equal

Expand All @@ -90,14 +89,112 @@


if TYPE_CHECKING:
from collections.abc import Collection, Mapping
from collections.abc import Collection, Iterable, Mapping

from pytato.function import NamedCallResult
from pytato.loopy import LoopyCall


GraphNodeT = TypeVar("GraphNodeT")

BindingName: TypeAlias = str
P = ParamSpec("P")

IGNORE_VARIABLE_STR: str = "IGNORE"
BINDING_NAME_RESERVED_PATTERN = re.compile(r"^(_in?(0|([1-9][0-9]*)))$")
nkoskelo marked this conversation as resolved.
Show resolved Hide resolved

# {{{ BindingSubscriptsUsedInIndexLambda


class BindingSubscriptsUsedInIndexLambda(CombineMapper[dict[BindingName,
nkoskelo marked this conversation as resolved.
Show resolved Hide resolved
set[tuple[prim.Variable, ...]]],
[]]):
"""
Return all the subscript expressions used by a variable specified by BindingName.

Ex:
_in1[_0,_1] would result in an dictionary entry {"_in1": ("_0", "_1")}.

In the case that a subscript expression is not a variable, like in

Ex:
_in1[_0, 0]

that subscript will be replaced with a `Variable` with the name IGNORE.

So the second example would result in an dictionary entry
{"_in1": ("_0","IGNORE")}.
"""
def combine(self,
values: Iterable[dict[BindingName,
set[tuple[prim.Variable, ...]]]]) \
-> dict[BindingName, set[tuple[prim.Variable, ...]]]:
out: dict[BindingName, set[tuple[prim.Variable, ...]]] = {}
from functools import reduce
return reduce(lambda x, y: x | y, values, out)
nkoskelo marked this conversation as resolved.
Show resolved Hide resolved

def map_subscript(self, expr: prim.Subscript) -> dict[BindingName,
set[tuple[prim.Variable, ...]]]:
"""
Record the indexing expression if the Subscript expression has a prim.Variable
as its aggregate. This will record an ignorable variable for each part of the
indexing expression that is not already a prim.Variable.
"""

if isinstance(expr.aggregate, prim.Variable):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing a recursion. Subscripts can contain subscripts.

name: BindingName = expr.aggregate.name

index = tuple(val if isinstance(val, prim.Variable)
else prim.Variable(name=IGNORE_VARIABLE_STR)
for val in expr.index_tuple)
base: dict[BindingName, set[tuple[prim.Variable, ...]]] = {name:
{index}}
assert base
return base
return {}

def map_algebraic_leaf(self, expr: prim.ExpressionNode) -> dict[BindingName,
set[tuple[prim.Variable, ...]]]:

return {}

def map_constant(self, expr: object) -> dict[BindingName,
set[tuple[prim.Variable, ...]]]:
return {}
# }}}

# {{{ Tag Reduction expressions


@dataclass(init=True, repr=True)
class TagReductionAxesMapper(ScalarIdentityMapper[[]]):
nkoskelo marked this conversation as resolved.
Show resolved Hide resolved

axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]]
array_expr: IndexLambda

def map_subscript(self,
expr: prim.Subscript,
) -> prim.Expression:

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to return the subscript unchanged, which means this should probably be a different type of mapper.

if isinstance(expr.aggregate, prim.Variable):
name: BindingName = expr.aggregate.name

for iaxis, val in enumerate(expr.index_tuple):
if isinstance(val, prim.Variable):
if val.name in self.array_expr.var_to_reduction_descr.keys() and \
name in self.array_expr.bindings.keys():
# We matched the reduction axis.
# Now we need to add the tag to the original expression.
my_key = (self.array_expr.bindings[name], iaxis)
self.array_expr = self.array_expr.with_tagged_reduction(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if you see more than one?

name, self.axis_to_tags.get(my_key, [])
)

assert isinstance(self.array_expr, IndexLambda)

return super().map_subscript(expr)
# }}}


# {{{ AxesTagsEquationCollector

Expand Down Expand Up @@ -236,66 +333,32 @@ def map_index_lambda(self, expr: IndexLambda) -> None:
for bnd in expr.bindings.values():
self.rec(bnd)

try:
hlo = index_lambda_to_high_level_op(expr)
except UnknownIndexLambdaExpr:
from warnings import warn
warn(f"'{expr}' is an unknown index lambda type"
" no tags were propagated across it.", stacklevel=1)
# no propagation semantics implemented for such cases
return

if isinstance(hlo, BinaryOp):
subexprs: tuple[ArrayOrScalar, ...] = (hlo.x1, hlo.x2)
elif isinstance(hlo, WhereOp):
subexprs = (hlo.condition, hlo.then, hlo.else_)
elif isinstance(hlo, FullOp):
# A full-op does not impose any equations
subexprs = ()
elif isinstance(hlo, BroadcastOp):
subexprs = (hlo.x,)
elif isinstance(hlo, C99CallOp):
subexprs = hlo.args
elif isinstance(hlo, ReduceOp):

# {{{ ReduceOp doesn't quite involve broadcasting

i_out_axis = 0
for i_in_axis in range(hlo.x.ndim):
if i_in_axis not in hlo.axes:
self.record_equation(
self.get_var_for_axis(hlo.x, i_in_axis),
self.get_var_for_axis(expr, i_out_axis)
)
i_out_axis += 1

assert i_out_axis == expr.ndim

# }}}

return

else:
raise NotImplementedError(type(hlo))

for subexpr in subexprs:
if isinstance(subexpr, Array):
for i_in_axis, i_out_axis in zip(
range(subexpr.ndim),
range(expr.ndim-subexpr.ndim, expr.ndim),
strict=True):
in_dim = subexpr.shape[i_in_axis]
out_dim = expr.shape[i_out_axis]
if are_shape_components_equal(in_dim, out_dim):
self.record_equation(
self.get_var_for_axis(subexpr, i_in_axis),
self.get_var_for_axis(expr, i_out_axis)
)
index_expr_used = BindingSubscriptsUsedInIndexLambda()(expr.expr)

assert len(expr.shape) == expr.ndim
nkoskelo marked this conversation as resolved.
Show resolved Hide resolved

for vname, set_of_ind_tuple in index_expr_used.items():
for ind_tuple in set_of_ind_tuple:
for axis_ind in range(len(ind_tuple)):
var_ind_name = ind_tuple[axis_ind]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use enumerate.

if IDX_LAMBDA_JUST_REDUCTIONS.fullmatch(var_ind_name.name):
nkoskelo marked this conversation as resolved.
Show resolved Hide resolved
# Reduction axis. We can ignore it.
pass
elif BINDING_NAME_RESERVED_PATTERN.fullmatch(var_ind_name.name):
# Variable name axis.
pass
elif var_ind_name.name == IGNORE_VARIABLE_STR:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if the user has a variable named IGNORE?

# This is not directly represented in output axes. Ignore.
pass
elif IDX_LAMBDA_INAME.fullmatch(var_ind_name.name):
# matched with an iname.
inum = int(var_ind_name.name[1:])
lhs: str = self.get_var_for_axis(expr.bindings[vname], axis_ind)
rhs: str = self.get_var_for_axis(expr, inum)
self.record_equation(lhs, rhs)
else:
# i_in_axis is broadcasted => do not propagate
assert are_shape_components_equal(in_dim, 1)
else:
assert isinstance(subexpr, SCALAR_CLASSES)
raise ValueError(f"Unknown index name used in {vname}")
return
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return


def map_stack(self, expr: Stack) -> None:
"""
Expand Down Expand Up @@ -632,19 +695,18 @@ def rec(self, expr: ArrayOrNames) -> Any:
self.axis_to_tags.get((arg, iaxis), [])
)

if isinstance(expr, IndexLambda):
assert isinstance(expr_copy, IndexLambda)
try:
hlo = index_lambda_to_high_level_op(expr)
except UnknownIndexLambdaExpr:
pass
else:
if isinstance(hlo, ReduceOp):
for iaxis, redn_var in hlo.axes.items():
expr_copy = expr_copy.with_tagged_reduction(
redn_var,
self.axis_to_tags.get((hlo.x, iaxis), [])
)
if isinstance(expr_copy, IndexLambda):
nkoskelo marked this conversation as resolved.
Show resolved Hide resolved
if expr_copy.var_to_reduction_descr:
# This is a reduction operation.
# We need to find the axes that are reduced over
# and update the tag/tag them appropriately.
mymapper: TagReductionAxesMapper = \
TagReductionAxesMapper(
self.axis_to_tags,
expr_copy
)
mymapper(expr_copy.expr) # Tag the axes
expr_copy = mymapper.array_expr # Recover it.
nkoskelo marked this conversation as resolved.
Show resolved Hide resolved

# }}}

Expand Down
14 changes: 7 additions & 7 deletions test/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,11 +916,11 @@ def test_einsum_with_parameterized_shapes(ctx_factory):
m_in = np.random.randint(2, 20)
n_in = np.random.randint(2, 20)

def _get_a_shape(_m, _n):
return (2*_m+1, 3*_n+7)
def _get_a_shape(m_, n_):
return (2*m_+1, 3*n_+7)

def _get_x_shape(_m, _n):
return (3*_n+7, )
def _get_x_shape(_m, n_):
return (3*n_+7, )

A_in = np.random.rand(*_get_a_shape(m_in, n_in)) # noqa: N806
x_in = np.random.rand(*_get_x_shape(m_in, n_in))
Expand Down Expand Up @@ -1570,16 +1570,16 @@ def test_regression_reduction_in_conditional(ctx_factory):
ctx = ctx_factory()
cq = cl.CommandQueue(ctx)

def kernel(usr_np, _pt_data_9):
pt_tmp_53 = _pt_data_9 @ _pt_data_9
def kernel(usr_np, pt_data_9):
pt_tmp_53 = pt_data_9 @ pt_data_9
pt_tmp_42 = usr_np.maximum(pt_tmp_53, pt_tmp_53)
pt_tmp_27 = usr_np.sum(pt_tmp_42)
pt_tmp_0 = usr_np.maximum(pt_tmp_27, pt_tmp_53)
return pt_tmp_0

def get_np_input_args():
return {
"_pt_data_9": np.ones((2, 2)),
"pt_data_9": np.ones((2, 2)),
}

np_inputs = get_np_input_args()
Expand Down
Loading
Loading