From afc3e9e68f70b2777c3d39ba9a3b5e99cffafc1d Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 4 Nov 2024 15:41:45 -0600 Subject: [PATCH 01/28] Move the AxesEquationCollector to not use the raising.py operations. --- pytato/transform/metadata.py | 173 ++++++++++++++++++++--------------- 1 file changed, 98 insertions(+), 75 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 139b7bf5b..31c2a4a1a 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -57,7 +57,6 @@ AbstractResultWithNamedArrays, AdvancedIndexInContiguousAxes, Array, - ArrayOrScalar, AxisPermutation, BasicIndex, Concatenate, @@ -71,19 +70,8 @@ Reshape, Stack, ) -from pytato.diagnostic import UnknownIndexLambdaExpr from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.function import NamedCallResult -from pytato.raising import ( - BinaryOp, - BroadcastOp, - C99CallOp, - FullOp, - ReduceOp, - WhereOp, - index_lambda_to_high_level_op, -) -from pytato.scalar_expr import SCALAR_CLASSES from pytato.transform import ArrayOrNames, CopyMapper, Mapper from pytato.utils import are_shape_components_equal, are_shapes_equal @@ -97,6 +85,44 @@ GraphNodeT = TypeVar("GraphNodeT") +import pymbolic.primitives as prim + +from pytato.scalar_expr import ( + IdentityMapper as ScalarMap, +) + + +class AxesUsedMapper(ScalarMap): + """ + Determine which axes are used in the scalar expression and which ones just + flow through the expression. + """ + + def __init__(self, var_names_in_use: list[str]): + + self.var_names_in_use: list[str] = var_names_in_use + + self.usage_dict: Mapping[str, list[str]] = {vname: [] for vname in + self.var_names_in_use} + + def map_subscript(self, expr: prim.Subscript) -> None: + + # Check the variable name and if it matches a use case record it. + + name = expr.aggregate.name + if name in self.var_names_in_use: + + self.usage_dict[name].append(str(expr)) + + self.rec(expr.index) + + def map_variable(self, expr: prim.Variable) -> None: + name = expr.name + + if name in self.var_names_in_use: + + self.usage_dict[name].append(str(expr)) + # {{{ AxesTagsEquationCollector @@ -235,66 +261,62 @@ def map_index_lambda(self, expr: IndexLambda) -> None: for bnd in expr.bindings.values(): self.rec(bnd) - try: - hlo = index_lambda_to_high_level_op(expr) - except UnknownIndexLambdaExpr: - from warnings import warn - warn(f"'{expr}' is an unknown index lambda type" - " no tags were propagated across it.", stacklevel=1) - # no propagation semantics implemented for such cases - return + mymap = AxesUsedMapper(list(expr.bindings.keys())) - if isinstance(hlo, BinaryOp): - subexprs: tuple[ArrayOrScalar, ...] = (hlo.x1, hlo.x2) - elif isinstance(hlo, WhereOp): - subexprs = (hlo.condition, hlo.then, hlo.else_) - elif isinstance(hlo, FullOp): - # A full-op does not impose any equations - subexprs = () - elif isinstance(hlo, BroadcastOp): - subexprs = (hlo.x,) - elif isinstance(hlo, C99CallOp): - subexprs = hlo.args - elif isinstance(hlo, ReduceOp): - - # {{{ ReduceOp doesn't quite involve broadcasting + mymap(expr.expr) + + keys = list(expr.bindings.keys()) + + for k in keys: + all_match = True + # Confirm self-consistency. 
+ start = mymap.usage_dict[k][0] + for i in range(1, len(mymap.usage_dict[k])): + if start != mymap.usage_dict[k][i]: + all_match = False + + assert all_match + + out_shape = expr.shape + assert len(out_shape) == expr.ndim + in_shape = [expr.bindings[k].shape for k in keys] + + if expr.var_to_reduction_descr:\ + # We are in a reduction operation and so need to handle differently. i_out_axis = 0 - for i_in_axis in range(hlo.x.ndim): - if i_in_axis not in hlo.axes: - self.record_equation( - self.get_var_for_axis(hlo.x, i_in_axis), - self.get_var_for_axis(expr, i_out_axis) - ) + assert len(in_shape) == 1 # There should be only 1 input. + for i_in_axis in range(len(in_shape[0])): + out_dim = out_shape[i_out_axis] + in_dim = in_shape[0][i_in_axis] + if are_shape_components_equal(in_dim, out_dim): + val = (self.get_var_for_axis(expr.bindings[keys[0]], i_in_axis), + self.get_var_for_axis(expr, i_out_axis)) + self.equations.append(val) i_out_axis += 1 - assert i_out_axis == expr.ndim + if i_out_axis == expr.ndim: + return + return - # }}} + for input_term in range(len(in_shape)): + input_length = len(in_shape[input_term]) - return + for i_in_axis, i_out_axis in zip( + range(input_length), + range(expr.ndim - input_length, expr.ndim)): + in_dim = in_shape[input_term][i_in_axis] + out_dim = out_shape[i_out_axis] + if are_shape_components_equal(in_dim, out_dim): + val = (self.get_var_for_axis(expr.bindings[keys[input_term]], + i_in_axis), + self.get_var_for_axis(expr, i_out_axis)) + self.equations.append(val) - else: - raise NotImplementedError(type(hlo)) + else: + assert are_shape_components_equal(in_dim, 1) - for subexpr in subexprs: - if isinstance(subexpr, Array): - for i_in_axis, i_out_axis in zip( - range(subexpr.ndim), - range(expr.ndim-subexpr.ndim, expr.ndim), - strict=True): - in_dim = subexpr.shape[i_in_axis] - out_dim = expr.shape[i_out_axis] - if are_shape_components_equal(in_dim, out_dim): - self.record_equation( - self.get_var_for_axis(subexpr, i_in_axis), - self.get_var_for_axis(expr, i_out_axis) - ) - else: - # i_in_axis is broadcasted => do not propagate - assert are_shape_components_equal(in_dim, 1) - else: - assert isinstance(subexpr, SCALAR_CLASSES) + return def map_stack(self, expr: Stack) -> None: """ @@ -632,19 +654,20 @@ def rec(self, expr: ArrayOrNames) -> Any: ) if isinstance(expr, IndexLambda): - assert isinstance(expr_copy, IndexLambda) - try: - hlo = index_lambda_to_high_level_op(expr) - except UnknownIndexLambdaExpr: - pass - else: - if isinstance(hlo, ReduceOp): - for iaxis, redn_var in hlo.axes.items(): - expr_copy = expr_copy.with_tagged_reduction( - redn_var, - self.axis_to_tags.get((hlo.x, iaxis), []) - ) + if expr.var_to_reduction_descr: + # This is a reduction operation. + # We need to find the axes that are reduced over + # and update the tag/tag them appropriately. + for iaxis in range(len(expr.expr.inner_expr.index_tuple)): + name = expr.expr.inner_expr.index_tuple[iaxis].name + if name in expr.var_to_reduction_descr.keys(): + assert len(list(expr.bindings.keys())) == 1 + my_arr: Array = next(iter(expr.bindings.values())) + expr_copy = expr_copy.with_tagged_reduction( + name, + self.axis_to_tags.get((my_arr, iaxis), []) + ) # }}} self._cache[key] = expr_copy From 574ab2369611c49703a607ff96e4647f5dd4204e Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 5 Nov 2024 12:29:00 -0600 Subject: [PATCH 02/28] Remove the usage check. 
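
The check removed here asserted that every recorded subscript of a given
binding was textually identical before any equations were emitted. A rough
illustration of what that assertion did, using the usual IndexLambda naming
conventions (_in0 is a binding, _0/_1 are output inames):

    # usage_dict["_in0"] for
    #     _in0[_0, _1] + _in0[_0, _1]
    # holds two identical entries, so the assertion passed, whereas
    #     _in0[_0, _1] + _in0[_1, _0]
    # records two differing entries and would have tripped it.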
--- pytato/transform/metadata.py | 52 ------------------------------------ 1 file changed, 52 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 31c2a4a1a..e6aeda372 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -85,44 +85,6 @@ GraphNodeT = TypeVar("GraphNodeT") -import pymbolic.primitives as prim - -from pytato.scalar_expr import ( - IdentityMapper as ScalarMap, -) - - -class AxesUsedMapper(ScalarMap): - """ - Determine which axes are used in the scalar expression and which ones just - flow through the expression. - """ - - def __init__(self, var_names_in_use: list[str]): - - self.var_names_in_use: list[str] = var_names_in_use - - self.usage_dict: Mapping[str, list[str]] = {vname: [] for vname in - self.var_names_in_use} - - def map_subscript(self, expr: prim.Subscript) -> None: - - # Check the variable name and if it matches a use case record it. - - name = expr.aggregate.name - if name in self.var_names_in_use: - - self.usage_dict[name].append(str(expr)) - - self.rec(expr.index) - - def map_variable(self, expr: prim.Variable) -> None: - name = expr.name - - if name in self.var_names_in_use: - - self.usage_dict[name].append(str(expr)) - # {{{ AxesTagsEquationCollector @@ -261,22 +223,8 @@ def map_index_lambda(self, expr: IndexLambda) -> None: for bnd in expr.bindings.values(): self.rec(bnd) - mymap = AxesUsedMapper(list(expr.bindings.keys())) - - mymap(expr.expr) - keys = list(expr.bindings.keys()) - for k in keys: - all_match = True - # Confirm self-consistency. - start = mymap.usage_dict[k][0] - for i in range(1, len(mymap.usage_dict[k])): - if start != mymap.usage_dict[k][i]: - all_match = False - - assert all_match - out_shape = expr.shape assert len(out_shape) == expr.ndim in_shape = [expr.bindings[k].shape for k in keys] From e02a800072d11f64afceb80dca5eca77a8f75713 Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 5 Nov 2024 13:59:08 -0600 Subject: [PATCH 03/28] Look at the variables directly and as long as we are using the reserved variables we will know the right relationship. --- pytato/transform/metadata.py | 91 +++++++++++++++++++++++------------- 1 file changed, 58 insertions(+), 33 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index e6aeda372..815ae2dcc 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -85,6 +85,39 @@ GraphNodeT = TypeVar("GraphNodeT") +import pymbolic.primitives as prim +from pytato.scalar_expr import ( + IdentityMapper as ScalarMapper +) +import re + +class AxesUsedMapper(ScalarMapper): + """ + Determine which axes are used in the scalar expressionand which ones just + flow through the expression. 
+ """ + + def __init__(self, var_names_in_use: list[str]): + self.var_names_in_use: list[str] = var_names_in_use + + self.usage_dict: Mapping[str, list[tuple[prim.Expression, ...]]] = {vname: [] \ + for vname in + self.var_names_in_use} + def map_subscript(self, expr: prim.Subscript) -> None: + + name = expr.aggregate.name + if name in self.var_names_in_use: + self.usage_dict[name].append(expr.index_tuple) + + self.rec(expr.index) + + def map_variable(self, expr: prim.Variable) -> None: + name = expr.name + + if name in self.var_names_in_use: + + self.usage_dict[name].append(expr) + # {{{ AxesTagsEquationCollector @@ -225,44 +258,36 @@ def map_index_lambda(self, expr: IndexLambda) -> None: keys = list(expr.bindings.keys()) + mymapper = AxesUsedMapper(keys) + + mymapper(expr.expr) + out_shape = expr.shape assert len(out_shape) == expr.ndim in_shape = [expr.bindings[k].shape for k in keys] - if expr.var_to_reduction_descr:\ - # We are in a reduction operation and so need to handle differently. - - i_out_axis = 0 - assert len(in_shape) == 1 # There should be only 1 input. - for i_in_axis in range(len(in_shape[0])): - out_dim = out_shape[i_out_axis] - in_dim = in_shape[0][i_in_axis] - if are_shape_components_equal(in_dim, out_dim): - val = (self.get_var_for_axis(expr.bindings[keys[0]], i_in_axis), - self.get_var_for_axis(expr, i_out_axis)) - self.equations.append(val) - i_out_axis += 1 - - if i_out_axis == expr.ndim: - return - return - - for input_term in range(len(in_shape)): - input_length = len(in_shape[input_term]) - - for i_in_axis, i_out_axis in zip( - range(input_length), - range(expr.ndim - input_length, expr.ndim)): - in_dim = in_shape[input_term][i_in_axis] - out_dim = out_shape[i_out_axis] - if are_shape_components_equal(in_dim, out_dim): - val = (self.get_var_for_axis(expr.bindings[keys[input_term]], - i_in_axis), - self.get_var_for_axis(expr, i_out_axis)) - self.equations.append(val) + reserved_reduction_pattern = re.compile("^(_r[0-9]+)$") + reserved_iname_pattern = re.compile("^(_[0-9]+)$") + for ikey, key in enumerate(keys): + if len(mymapper.usage_dict[key]) > 0: + for tup_ind in range(len(mymapper.usage_dict[key][0])): + vname = mymapper.usage_dict[key][0][tup_ind] + if isinstance(vname, prim.Variable): + if reserved_reduction_pattern.match(vname.name): + # Reduction axis. We can ignore it. + pass + elif vname.name[:3] == "_in": + # Variable name axis. + pass + elif reserved_iname_pattern.match(vname.name): + # matched with an iname. + inum = int(vname.name[1:]) + val = (self.get_var_for_axis(expr.bindings[key], tup_ind), + self.get_var_for_axis(expr, inum)) + self.equations.append(val) + else: + raise ValueError(f"Unknown index name used in {vname}") - else: - assert are_shape_components_equal(in_dim, 1) return From 09488439d1b6f1a5930de2b936bef85081086690 Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 5 Nov 2024 14:03:38 -0600 Subject: [PATCH 04/28] Correct ruff suggestions. 
--- pytato/transform/metadata.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 815ae2dcc..236b120b1 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -85,12 +85,13 @@ GraphNodeT = TypeVar("GraphNodeT") -import pymbolic.primitives as prim -from pytato.scalar_expr import ( - IdentityMapper as ScalarMapper -) import re +import pymbolic.primitives as prim + +from pytato.scalar_expr import IdentityMapper as ScalarMapper + + class AxesUsedMapper(ScalarMapper): """ Determine which axes are used in the scalar expressionand which ones just @@ -100,9 +101,10 @@ class AxesUsedMapper(ScalarMapper): def __init__(self, var_names_in_use: list[str]): self.var_names_in_use: list[str] = var_names_in_use - self.usage_dict: Mapping[str, list[tuple[prim.Expression, ...]]] = {vname: [] \ + self.usage_dict: Mapping[str, list[tuple[prim.Expression, ...]]] = {vname: [] for vname in self.var_names_in_use} + def map_subscript(self, expr: prim.Subscript) -> None: name = expr.aggregate.name @@ -264,11 +266,10 @@ def map_index_lambda(self, expr: IndexLambda) -> None: out_shape = expr.shape assert len(out_shape) == expr.ndim - in_shape = [expr.bindings[k].shape for k in keys] reserved_reduction_pattern = re.compile("^(_r[0-9]+)$") reserved_iname_pattern = re.compile("^(_[0-9]+)$") - for ikey, key in enumerate(keys): + for key in keys: if len(mymapper.usage_dict[key]) > 0: for tup_ind in range(len(mymapper.usage_dict[key][0])): vname = mymapper.usage_dict[key][0][tup_ind] @@ -288,7 +289,6 @@ def map_index_lambda(self, expr: IndexLambda) -> None: else: raise ValueError(f"Unknown index name used in {vname}") - return def map_stack(self, expr: Stack) -> None: From 953a643c2dfa4af470dbbd31c2ab8ac650e40ed7 Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 6 Nov 2024 15:41:48 -0600 Subject: [PATCH 05/28] Only record usage if the array is indexed in some way. --- pytato/transform/metadata.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 236b120b1..0ee1d921e 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -113,13 +113,6 @@ def map_subscript(self, expr: prim.Subscript) -> None: self.rec(expr.index) - def map_variable(self, expr: prim.Variable) -> None: - name = expr.name - - if name in self.var_names_in_use: - - self.usage_dict[name].append(expr) - # {{{ AxesTagsEquationCollector From 88eac820328f9dc7e38d8042b6cfd53568498de5 Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 6 Nov 2024 16:47:04 -0600 Subject: [PATCH 06/28] Add a unit test case which is unbroadcastable but is still a legal pytato expression. 
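
The new test constructs an IndexLambda whose bindings are indexed along the
leading axes of the output, something numpy-style broadcasting (which aligns
trailing axes) would never produce, yet which is still a legal pytato
expression. A minimal sketch of the shapes involved, mirroring the test
added below (_in0 is bound to a, _in1 to b):

    import pytato as pt

    a = pt.make_placeholder("a", (512, 10, 8))
    b = pt.make_placeholder("b", (512, 10))
    # z[_0, _1, _2] = a[_0, _1, _2] + b[_0, _1]
    # b's axes line up with the first two axes of z, so broadcast-based
    # propagation cannot relate them, while the subscript-based collector can.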
--- test/test_pytato.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/test/test_pytato.py b/test/test_pytato.py index 271c8fb01..92e555a46 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1324,6 +1324,43 @@ def test_unify_axes_tags(): # }}} +def test_unify_axes_tags_with_unbroadcastable_expressions(): + + a = pt.make_placeholder("a", (512, 10, 8)) + b = pt.make_placeholder("b", (512, 10)) + from testlib import BazTag, FooTag, QuuxTag, TestlibTag + + a = a.with_tagged_axis(0, BazTag()) + a = a.with_tagged_axis(1, QuuxTag()) + a = a.with_tagged_axis(2, FooTag()) + + from immutabledict import immutabledict + + import pymbolic.primitives as prim + + x = prim.Subscript(prim.Variable("_in0"), (prim.Variable("_0"), prim.Variable("_1"), + prim.Variable("_2"))) + y = prim.Subscript(prim.Variable("_in1"), + (prim.Variable("_0"), prim.Variable("_1"))) + + z = pt.IndexLambda(expr=x+y, bindings=immutabledict({"_in0": a, "_in1": b}), + shape=(512, 10, 8), tags=pt.array._get_default_tags(), + axes=pt.array._get_default_axes(3), + dtype=float, + var_to_reduction_descr=immutabledict({})) + + z_unified = pt.unify_axes_tags(z) + + assert (z_unified.axes[0].tags_of_type(TestlibTag) == frozenset([BazTag()])) + assert (z_unified.axes[1].tags_of_type(TestlibTag) == frozenset([QuuxTag()])) + assert (z_unified.axes[2].tags_of_type(TestlibTag) == frozenset([FooTag()])) + + for key in z_unified.bindings.keys(): + term = z_unified.bindings[key] + assert (term.axes[0].tags_of_type(TestlibTag) == frozenset([BazTag()])) + assert (term.axes[1].tags_of_type(TestlibTag) == frozenset([QuuxTag()])) + + def test_rewrite_einsums_with_no_broadcasts(): a = pt.make_placeholder("a", (10, 4, 1)) b = pt.make_placeholder("b", (10, 1, 4)) From 52dc445af416c46bb021aea126f5061cd95cae4f Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 7 Nov 2024 12:57:53 -0600 Subject: [PATCH 07/28] Add a unit test and split out a reserved pattern for the reductions and the inames in an IndexLambda. 
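
Splitting the reserved-name pattern lets the metadata code tell output
inames (_0, _1, ...) apart from reduction inames (_r0, _r1, ...). A short
sketch of how the three anchored patterns behave, matching what the new
unit test checks:

    from pytato.scalar_expr import (
        IDX_LAMBDA_INAME, IDX_LAMBDA_JUST_REDUCTIONS, IDX_LAMBDA_RE)

    assert IDX_LAMBDA_RE.fullmatch("_0") and IDX_LAMBDA_INAME.fullmatch("_0")
    assert (IDX_LAMBDA_RE.fullmatch("_r0")
            and IDX_LAMBDA_JUST_REDUCTIONS.fullmatch("_r0"))
    # names with leading zeros match none of the three
    assert not any(p.fullmatch("_00")
                   for p in (IDX_LAMBDA_RE, IDX_LAMBDA_INAME,
                             IDX_LAMBDA_JUST_REDUCTIONS))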
--- pytato/scalar_expr.py | 5 +++-- pytato/transform/metadata.py | 12 +++++++----- test/test_pytato.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 032005049..24507134e 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -163,8 +163,9 @@ def map_reduce(self, expr: Reduce) -> ScalarExpression: for name, bound in expr.bounds.items()})) -IDX_LAMBDA_RE = re.compile("_r?(0|([1-9][0-9]*))") - +IDX_LAMBDA_RE = re.compile("^(_r?(0|([1-9][0-9]*)))$") +IDX_LAMBDA_INAME = re.compile("^(_(0|([1-9][0-9]*)))$") +IDX_LAMBDA_JUST_REDUCTIONS = re.compile("^(_r(0|([1-9][0-9]*)))$") class DependencyMapper(DependencyMapperBase[P]): def __init__(self, *, diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 0ee1d921e..8834b0f81 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -89,7 +89,11 @@ import pymbolic.primitives as prim -from pytato.scalar_expr import IdentityMapper as ScalarMapper +from pytato.scalar_expr import ( + IdentityMapper as ScalarMapper, + IDX_LAMBDA_JUST_REDUCTIONS, + IDX_LAMBDA_INAME +) class AxesUsedMapper(ScalarMapper): @@ -260,20 +264,18 @@ def map_index_lambda(self, expr: IndexLambda) -> None: out_shape = expr.shape assert len(out_shape) == expr.ndim - reserved_reduction_pattern = re.compile("^(_r[0-9]+)$") - reserved_iname_pattern = re.compile("^(_[0-9]+)$") for key in keys: if len(mymapper.usage_dict[key]) > 0: for tup_ind in range(len(mymapper.usage_dict[key][0])): vname = mymapper.usage_dict[key][0][tup_ind] if isinstance(vname, prim.Variable): - if reserved_reduction_pattern.match(vname.name): + if IDX_LAMBDA_JUST_REDUCTIONS.fullmatch(vname.name): # Reduction axis. We can ignore it. pass elif vname.name[:3] == "_in": # Variable name axis. pass - elif reserved_iname_pattern.match(vname.name): + elif IDX_LAMBDA_INAME.fullmatch(vname.name): # matched with an iname. 
inum = int(vname.name[1:]) val = (self.get_var_for_axis(expr.bindings[key], tup_ind), diff --git a/test/test_pytato.py b/test/test_pytato.py index 92e555a46..66711a970 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1213,6 +1213,34 @@ def test_lower_to_index_lambda(): assert isinstance(binding, Reshape) +def test_reserved_scalar_iname_patterns(): + from pytato.scalar_expr import ( + IDX_LAMBDA_INAME, + IDX_LAMBDA_JUST_REDUCTIONS, + IDX_LAMBDA_RE, + ) + + test_strings = ["_r0", "_r000", "_r01", "_00", "_r101", "_1", "_0", "_101"] + + assert IDX_LAMBDA_RE.fullmatch(test_strings[0]) + assert not IDX_LAMBDA_INAME.fullmatch(test_strings[0]) + assert IDX_LAMBDA_JUST_REDUCTIONS.fullmatch(test_strings[0]) + + for pat in [IDX_LAMBDA_INAME, IDX_LAMBDA_RE, IDX_LAMBDA_JUST_REDUCTIONS]: + assert not pat.fullmatch(test_strings[1]) + assert not pat.fullmatch(test_strings[2]) + assert not pat.fullmatch(test_strings[3]) + + assert IDX_LAMBDA_RE.fullmatch(test_strings[4]) + assert not IDX_LAMBDA_INAME.fullmatch(test_strings[4]) + assert IDX_LAMBDA_JUST_REDUCTIONS.fullmatch(test_strings[4]) + + for i in range(5, len(test_strings)): + assert IDX_LAMBDA_RE.fullmatch(test_strings[i]) + assert IDX_LAMBDA_INAME.fullmatch(test_strings[i]) + assert not IDX_LAMBDA_JUST_REDUCTIONS.fullmatch(test_strings[i]) + + def test_cached_walk_mapper_with_extra_args(): from testlib import RandomDAGContext, make_random_dag From 1031868adacf3b355e69aa049d659af342360ad3 Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 7 Nov 2024 13:23:42 -0600 Subject: [PATCH 08/28] Fix ruff suggestions. --- pytato/transform/metadata.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 8834b0f81..2d95f56f9 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -85,14 +85,13 @@ GraphNodeT = TypeVar("GraphNodeT") -import re import pymbolic.primitives as prim from pytato.scalar_expr import ( - IdentityMapper as ScalarMapper, + IDX_LAMBDA_INAME, IDX_LAMBDA_JUST_REDUCTIONS, - IDX_LAMBDA_INAME + IdentityMapper as ScalarMapper, ) From 85e5395e95ee989457ce218d87bf6df8a6a84466 Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 7 Nov 2024 13:25:00 -0600 Subject: [PATCH 09/28] More ruff suggestions. --- pytato/scalar_expr.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 24507134e..43a8fc252 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -167,6 +167,7 @@ def map_reduce(self, expr: Reduce) -> ScalarExpression: IDX_LAMBDA_INAME = re.compile("^(_(0|([1-9][0-9]*)))$") IDX_LAMBDA_JUST_REDUCTIONS = re.compile("^(_r(0|([1-9][0-9]*)))$") + class DependencyMapper(DependencyMapperBase[P]): def __init__(self, *, include_idx_lambda_indices: bool = True, From a50d58e07be23bd65aa81ed15613bc116cefd7dd Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 7 Nov 2024 16:29:03 -0600 Subject: [PATCH 10/28] Make sure that we return a value if we need to. 
:) --- pytato/transform/metadata.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 2d95f56f9..28b832964 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -108,13 +108,17 @@ def __init__(self, var_names_in_use: list[str]): for vname in self.var_names_in_use} - def map_subscript(self, expr: prim.Subscript) -> None: + def map_subscript(self, expr: prim.Subscript) -> prim.Subscript: + """ + Record the indexing usage for the variable if we are tracking + the specific variable. + """ name = expr.aggregate.name if name in self.var_names_in_use: self.usage_dict[name].append(expr.index_tuple) - self.rec(expr.index) + return super().map_subscript(expr) # {{{ AxesTagsEquationCollector @@ -277,6 +281,7 @@ def map_index_lambda(self, expr: IndexLambda) -> None: elif IDX_LAMBDA_INAME.fullmatch(vname.name): # matched with an iname. inum = int(vname.name[1:]) + print(inum) val = (self.get_var_for_axis(expr.bindings[key], tup_ind), self.get_var_for_axis(expr, inum)) self.equations.append(val) From 921b55f95d5a1a2ba2b6ebc9e1c9035784bc8634 Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 25 Nov 2024 11:10:20 -0600 Subject: [PATCH 11/28] Working on mypy errors. --- pytato/transform/metadata.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 28b832964..7d650ea73 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -86,6 +86,8 @@ GraphNodeT = TypeVar("GraphNodeT") +from typing import ParamSpec + import pymbolic.primitives as prim from pytato.scalar_expr import ( @@ -95,7 +97,10 @@ ) -class AxesUsedMapper(ScalarMapper): +P = ParamSpec("P") + + +class AxesUsedMapper(ScalarMapper[P]): """ Determine which axes are used in the scalar expressionand which ones just flow through the expression. From e858110226bab474c578f663024c76c98b42c5b9 Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 11 Dec 2024 16:55:04 -0600 Subject: [PATCH 12/28] Respond to comments. --- pytato/transform/metadata.py | 98 +++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 45 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 7d650ea73..073a6413e 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -86,44 +86,58 @@ GraphNodeT = TypeVar("GraphNodeT") -from typing import ParamSpec +from collections.abc import Iterable +from typing import ParamSpec, TypeAlias import pymbolic.primitives as prim +from pymbolic.typing import Expression from pytato.scalar_expr import ( IDX_LAMBDA_INAME, IDX_LAMBDA_JUST_REDUCTIONS, - IdentityMapper as ScalarMapper, + CombineMapper, ) +BindingName: TypeAlias = str P = ParamSpec("P") -class AxesUsedMapper(ScalarMapper[P]): +class IndexExpressionsUsedInIndexLambda(CombineMapper[Mapping[BindingName, + set[tuple[Expression, ...]]], + P]): """ Determine which axes are used in the scalar expressionand which ones just flow through the expression. 
""" - - def __init__(self, var_names_in_use: list[str]): - self.var_names_in_use: list[str] = var_names_in_use - - self.usage_dict: Mapping[str, list[tuple[prim.Expression, ...]]] = {vname: [] - for vname in - self.var_names_in_use} - - def map_subscript(self, expr: prim.Subscript) -> prim.Subscript: + def combine(self, + values: Iterable[Mapping[BindingName, set[tuple[Expression, ...]]]]) \ + -> Mapping[BindingName, set[tuple[Expression, ...]]]: + out: dict[BindingName, set[tuple[Expression, ...]]] = {} + for val in values: + out.update(val) + return out + + def map_subscript(self, expr: prim.Subscript) -> Mapping[BindingName, + set[tuple[Expression, ...]]]: """ Record the indexing usage for the variable if we are tracking the specific variable. """ name = expr.aggregate.name - if name in self.var_names_in_use: - self.usage_dict[name].append(expr.index_tuple) + base = {name: expr.index_tuple} + + return self.combine([base, self.rec(expr.index)]) - return super().map_subscript(expr) + def map_constant(self, expr: object) -> Mapping[BindingName, + set[tuple[Expression, ...]]]: + return {} + + def map_algebraic_leaf(self, expr: prim.ExpressionNode) -> Mapping[BindingName, + set[tuple[Expression, ...]]]: + + return {} # {{{ AxesTagsEquationCollector @@ -263,36 +277,30 @@ def map_index_lambda(self, expr: IndexLambda) -> None: for bnd in expr.bindings.values(): self.rec(bnd) - keys = list(expr.bindings.keys()) - - mymapper = AxesUsedMapper(keys) - - mymapper(expr.expr) - - out_shape = expr.shape - assert len(out_shape) == expr.ndim - - for key in keys: - if len(mymapper.usage_dict[key]) > 0: - for tup_ind in range(len(mymapper.usage_dict[key][0])): - vname = mymapper.usage_dict[key][0][tup_ind] - if isinstance(vname, prim.Variable): - if IDX_LAMBDA_JUST_REDUCTIONS.fullmatch(vname.name): - # Reduction axis. We can ignore it. - pass - elif vname.name[:3] == "_in": - # Variable name axis. - pass - elif IDX_LAMBDA_INAME.fullmatch(vname.name): - # matched with an iname. - inum = int(vname.name[1:]) - print(inum) - val = (self.get_var_for_axis(expr.bindings[key], tup_ind), - self.get_var_for_axis(expr, inum)) - self.equations.append(val) - else: - raise ValueError(f"Unknown index name used in {vname}") - + index_expr_used = IndexExpressionsUsedInIndexLambda()(expr.expr) + + if __debug__: + out_shape = expr.shape + assert len(out_shape) == expr.ndim + + for vname, ind_tuple in index_expr_used.items(): + for axis_ind in range(len(ind_tuple)): + var_ind_name = ind_tuple[axis_ind] + if isinstance(var_ind_name, prim.Variable): + if IDX_LAMBDA_JUST_REDUCTIONS.fullmatch(var_ind_name.name): + # Reduction axis. We can ignore it. + pass + elif var_ind_name.name[:3] == "_in": + # Variable name axis. + pass + elif IDX_LAMBDA_INAME.fullmatch(var_ind_name.name): + # matched with an iname. + inum = int(var_ind_name.name[1:]) + val = (self.get_var_for_axis(expr.bindings[vname], axis_ind), + self.get_var_for_axis(expr, inum)) + self.equations.append(val) + else: + raise ValueError(f"Unknown index name used in {vname}") return def map_stack(self, expr: Stack) -> None: From 17df8719b027ae1514769bca6a2b1ecd5da6565e Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 11 Dec 2024 17:26:54 -0600 Subject: [PATCH 13/28] Update for ruff. 
--- pytato/scalar_expr.py | 6 +++--- pytato/transform/metadata.py | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 377e46262..7861e3504 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -163,9 +163,9 @@ def map_reduce(self, expr: Reduce) -> ScalarExpression: for name, bound in expr.bounds.items()})) -IDX_LAMBDA_RE = re.compile("^(_r?(0|([1-9][0-9]*)))$") -IDX_LAMBDA_INAME = re.compile("^(_(0|([1-9][0-9]*)))$") -IDX_LAMBDA_JUST_REDUCTIONS = re.compile("^(_r(0|([1-9][0-9]*)))$") +IDX_LAMBDA_RE = re.compile(r"^(_r?(0|([1-9][0-9]*)))$") +IDX_LAMBDA_INAME = re.compile(r"^(_(0|([1-9][0-9]*)))$") +IDX_LAMBDA_JUST_REDUCTIONS = re.compile(r"^(_r(0|([1-9][0-9]*)))$") class DependencyMapper(DependencyMapperBase[P]): diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index aeb486e23..3ddcaef75 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -43,8 +43,12 @@ from typing import ( TYPE_CHECKING, Any, + Mapping, TypeVar, + Iterable, + TypeAlias, cast, + ParamSpec, ) from bidict import bidict @@ -88,8 +92,6 @@ GraphNodeT = TypeVar("GraphNodeT") -from collections.abc import Iterable -from typing import ParamSpec, TypeAlias import pymbolic.primitives as prim from pymbolic.typing import Expression From 5b01c24b1e85a0851d48ffa2abc1aa49a92aad99 Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 11 Dec 2024 17:36:05 -0600 Subject: [PATCH 14/28] Move typing information to only import if type checking. --- pytato/transform/metadata.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 3ddcaef75..c222ee883 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -43,9 +43,7 @@ from typing import ( TYPE_CHECKING, Any, - Mapping, TypeVar, - Iterable, TypeAlias, cast, ParamSpec, @@ -78,12 +76,20 @@ from pytato.transform import ArrayOrNames, CopyMapper, Mapper from pytato.utils import are_shape_components_equal, are_shapes_equal +import pymbolic.primitives as prim +from pymbolic.typing import Expression + +from pytato.scalar_expr import ( + IDX_LAMBDA_INAME, + IDX_LAMBDA_JUST_REDUCTIONS, + CombineMapper, +) logger = logging.getLogger(__name__) if TYPE_CHECKING: - from collections.abc import Collection, Mapping + from collections.abc import Collection, Mapping, Iterable from pytato.function import NamedCallResult from pytato.loopy import LoopyCall @@ -91,18 +97,6 @@ GraphNodeT = TypeVar("GraphNodeT") - - -import pymbolic.primitives as prim -from pymbolic.typing import Expression - -from pytato.scalar_expr import ( - IDX_LAMBDA_INAME, - IDX_LAMBDA_JUST_REDUCTIONS, - CombineMapper, -) - - BindingName: TypeAlias = str P = ParamSpec("P") From 882312f739bbe30cd5bc2e3bab618384917985c4 Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 11 Dec 2024 22:05:57 -0600 Subject: [PATCH 15/28] More ruff CI. 
--- pytato/transform/metadata.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index c222ee883..c6f52d95e 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -40,20 +40,7 @@ import logging -from typing import ( - TYPE_CHECKING, - Any, - TypeVar, - TypeAlias, - cast, - ParamSpec, -) - from bidict import bidict - -from pytools import UniqueNameGenerator -from pytools.tag import Tag - from pytato.array import ( AbstractResultWithNamedArrays, AdvancedIndexInContiguousAxes, @@ -71,6 +58,7 @@ Reshape, Stack, ) + from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.function import NamedCallResult from pytato.transform import ArrayOrNames, CopyMapper, Mapper @@ -84,7 +72,17 @@ IDX_LAMBDA_JUST_REDUCTIONS, CombineMapper, ) +from pytools import UniqueNameGenerator +from pytools.tag import Tag +from typing import ( + TYPE_CHECKING, + Any, + TypeVar, + TypeAlias, + cast, + ParamSpec, +) logger = logging.getLogger(__name__) From 0feea14344d8203eb345b7e4d98bcc7ace065fce Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 11 Dec 2024 22:12:28 -0600 Subject: [PATCH 16/28] Add noqa: RUF052 for kernels in test_codegen.py. --- test/test_codegen.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_codegen.py b/test/test_codegen.py index 5b62ce468..558ad2afb 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -917,10 +917,10 @@ def test_einsum_with_parameterized_shapes(ctx_factory): n_in = np.random.randint(2, 20) def _get_a_shape(_m, _n): - return (2*_m+1, 3*_n+7) + return (2*_m+1, 3*_n+7) # noqa: RUF052 def _get_x_shape(_m, _n): - return (3*_n+7, ) + return (3*_n+7, ) # noqa: RUF052 A_in = np.random.rand(*_get_a_shape(m_in, n_in)) # noqa: N806 x_in = np.random.rand(*_get_x_shape(m_in, n_in)) @@ -1571,7 +1571,7 @@ def test_regression_reduction_in_conditional(ctx_factory): cq = cl.CommandQueue(ctx) def kernel(usr_np, _pt_data_9): - pt_tmp_53 = _pt_data_9 @ _pt_data_9 + pt_tmp_53 = _pt_data_9 @ _pt_data_9 # NOQA RUF052 pt_tmp_42 = usr_np.maximum(pt_tmp_53, pt_tmp_53) pt_tmp_27 = usr_np.sum(pt_tmp_42) pt_tmp_0 = usr_np.maximum(pt_tmp_27, pt_tmp_53) From 8848fd5e8abfea5655043663b9cc5cab370a351c Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 11 Dec 2024 22:39:07 -0600 Subject: [PATCH 17/28] Add assert statements for typing purposes. 
--- pytato/transform/metadata.py | 23 +++++++++++++---------- test/test_codegen.py | 4 ++-- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index c6f52d95e..996646eb9 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -61,17 +61,17 @@ from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.function import NamedCallResult +from pytato.scalar_expr import ( + IDX_LAMBDA_INAME, + IDX_LAMBDA_JUST_REDUCTIONS, + CombineMapper, +) from pytato.transform import ArrayOrNames, CopyMapper, Mapper from pytato.utils import are_shape_components_equal, are_shapes_equal import pymbolic.primitives as prim from pymbolic.typing import Expression -from pytato.scalar_expr import ( - IDX_LAMBDA_INAME, - IDX_LAMBDA_JUST_REDUCTIONS, - CombineMapper, -) from pytools import UniqueNameGenerator from pytools.tag import Tag @@ -101,7 +101,7 @@ class IndexExpressionsUsedInIndexLambda(CombineMapper[Mapping[BindingName, set[tuple[Expression, ...]]], - P]): + []]): """ Determine which axes are used in the scalar expressionand which ones just flow through the expression. @@ -122,9 +122,11 @@ def map_subscript(self, expr: prim.Subscript) -> Mapping[BindingName, """ name = expr.aggregate.name - base = {name: expr.index_tuple} - - return self.combine([base, self.rec(expr.index)]) + if isinstance(expr.aggregrate, prim.Variable): + name = expr.aggregate.name + base = {name: set((expr.index_tuple))} + return self.combine([base, self.rec(expr.index)]) + return {} def map_constant(self, expr: object) -> Mapping[BindingName, set[tuple[Expression, ...]]]: @@ -644,7 +646,8 @@ def rec(self, expr: ArrayOrNames) -> Any: if name in expr.var_to_reduction_descr.keys(): assert len(list(expr.bindings.keys())) == 1 my_arr: Array = next(iter(expr.bindings.values())) - + + assert isinstance(expr_copy, IndexLambda) expr_copy = expr_copy.with_tagged_reduction( name, self.axis_to_tags.get((my_arr, iaxis), []) diff --git a/test/test_codegen.py b/test/test_codegen.py index 558ad2afb..ac098e981 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -917,10 +917,10 @@ def test_einsum_with_parameterized_shapes(ctx_factory): n_in = np.random.randint(2, 20) def _get_a_shape(_m, _n): - return (2*_m+1, 3*_n+7) # noqa: RUF052 + return (2*_m+1, 3*_n+7) # noqa: RUF052 def _get_x_shape(_m, _n): - return (3*_n+7, ) # noqa: RUF052 + return (3*_n+7, ) # noqa: RUF052 A_in = np.random.rand(*_get_a_shape(m_in, n_in)) # noqa: N806 x_in = np.random.rand(*_get_x_shape(m_in, n_in)) From a30b3bce803c5a2fa9c2af8dcaa2ce4d656c1636 Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 11 Dec 2024 22:55:16 -0600 Subject: [PATCH 18/28] Reorganize. Ruff was out of date. 
:) --- pytato/transform/metadata.py | 47 ++++++++++++++++++------------------ test/test_codegen.py | 12 ++++----- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 996646eb9..063c39b0d 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -40,7 +40,23 @@ import logging +from collections.abc import Mapping +from typing import ( + TYPE_CHECKING, + Any, + ParamSpec, + TypeAlias, + TypeVar, + cast, +) + from bidict import bidict + +import pymbolic.primitives as prim +from pymbolic.typing import Expression +from pytools import UniqueNameGenerator +from pytools.tag import Tag + from pytato.array import ( AbstractResultWithNamedArrays, AdvancedIndexInContiguousAxes, @@ -58,7 +74,6 @@ Reshape, Stack, ) - from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.function import NamedCallResult from pytato.scalar_expr import ( @@ -69,25 +84,12 @@ from pytato.transform import ArrayOrNames, CopyMapper, Mapper from pytato.utils import are_shape_components_equal, are_shapes_equal -import pymbolic.primitives as prim -from pymbolic.typing import Expression - -from pytools import UniqueNameGenerator -from pytools.tag import Tag -from typing import ( - TYPE_CHECKING, - Any, - TypeVar, - TypeAlias, - cast, - ParamSpec, -) logger = logging.getLogger(__name__) if TYPE_CHECKING: - from collections.abc import Collection, Mapping, Iterable + from collections.abc import Collection, Iterable from pytato.function import NamedCallResult from pytato.loopy import LoopyCall @@ -121,20 +123,19 @@ def map_subscript(self, expr: prim.Subscript) -> Mapping[BindingName, the specific variable. """ - name = expr.aggregate.name - if isinstance(expr.aggregrate, prim.Variable): - name = expr.aggregate.name - base = {name: set((expr.index_tuple))} + if isinstance(expr.aggregate, prim.Variable): + name: BindingName = expr.aggregate.name + base = {name: set(expr.index_tuple)} return self.combine([base, self.rec(expr.index)]) return {} - def map_constant(self, expr: object) -> Mapping[BindingName, + def map_algebraic_leaf(self, expr: prim.ExpressionNode) -> Mapping[BindingName, set[tuple[Expression, ...]]]: + return {} - def map_algebraic_leaf(self, expr: prim.ExpressionNode) -> Mapping[BindingName, + def map_constant(self, expr: object) -> Mapping[BindingName, set[tuple[Expression, ...]]]: - return {} @@ -646,7 +647,7 @@ def rec(self, expr: ArrayOrNames) -> Any: if name in expr.var_to_reduction_descr.keys(): assert len(list(expr.bindings.keys())) == 1 my_arr: Array = next(iter(expr.bindings.values())) - + assert isinstance(expr_copy, IndexLambda) expr_copy = expr_copy.with_tagged_reduction( name, diff --git a/test/test_codegen.py b/test/test_codegen.py index ac098e981..b2827130a 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -916,11 +916,11 @@ def test_einsum_with_parameterized_shapes(ctx_factory): m_in = np.random.randint(2, 20) n_in = np.random.randint(2, 20) - def _get_a_shape(_m, _n): - return (2*_m+1, 3*_n+7) # noqa: RUF052 + def _get_a_shape(m_, n_): + return (2*m_+1, 3*n_+7) - def _get_x_shape(_m, _n): - return (3*_n+7, ) # noqa: RUF052 + def _get_x_shape(_m, n_): + return (3*n_+7, ) A_in = np.random.rand(*_get_a_shape(m_in, n_in)) # noqa: N806 x_in = np.random.rand(*_get_x_shape(m_in, n_in)) @@ -1570,8 +1570,8 @@ def test_regression_reduction_in_conditional(ctx_factory): ctx = ctx_factory() cq = cl.CommandQueue(ctx) - def kernel(usr_np, _pt_data_9): - pt_tmp_53 = 
_pt_data_9 @ _pt_data_9 # NOQA RUF052 + def kernel(usr_np, pt_data_9): + pt_tmp_53 = pt_data_9 @ pt_data_9 pt_tmp_42 = usr_np.maximum(pt_tmp_53, pt_tmp_53) pt_tmp_27 = usr_np.sum(pt_tmp_42) pt_tmp_0 = usr_np.maximum(pt_tmp_27, pt_tmp_53) From 07bc5ab1daea8cfd4ce7d5fa33a657c22fa0bf50 Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 12 Dec 2024 00:15:58 -0600 Subject: [PATCH 19/28] Fix some of the mypy errors. --- pytato/transform/metadata.py | 40 +++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 063c39b0d..e99a21b6d 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -53,7 +53,6 @@ from bidict import bidict import pymbolic.primitives as prim -from pymbolic.typing import Expression from pytools import UniqueNameGenerator from pytools.tag import Tag @@ -102,22 +101,23 @@ class IndexExpressionsUsedInIndexLambda(CombineMapper[Mapping[BindingName, - set[tuple[Expression, ...]]], + set[tuple[prim.Variable, ...]]], []]): """ Determine which axes are used in the scalar expressionand which ones just flow through the expression. """ def combine(self, - values: Iterable[Mapping[BindingName, set[tuple[Expression, ...]]]]) \ - -> Mapping[BindingName, set[tuple[Expression, ...]]]: - out: dict[BindingName, set[tuple[Expression, ...]]] = {} + values: Iterable[Mapping[BindingName, + set[tuple[prim.Variable, ...]]]]) \ + -> Mapping[BindingName, set[tuple[prim.Variable, ...]]]: + out: dict[BindingName, set[tuple[prim.Variable, ...]]] = {} for val in values: out.update(val) return out def map_subscript(self, expr: prim.Subscript) -> Mapping[BindingName, - set[tuple[Expression, ...]]]: + set[tuple[prim.Variable, ...]]]: """ Record the indexing usage for the variable if we are tracking the specific variable. @@ -125,17 +125,22 @@ def map_subscript(self, expr: prim.Subscript) -> Mapping[BindingName, if isinstance(expr.aggregate, prim.Variable): name: BindingName = expr.aggregate.name - base = {name: set(expr.index_tuple)} + + index = (val if isinstance(val, prim.Variable) + else prim.Variable(name="IGNORE") + for val in expr.index_tuple) + base: Mapping[BindingName, set[tuple[prim.Variable, ...]]] = {name: + set([tuple(index)])} return self.combine([base, self.rec(expr.index)]) return {} def map_algebraic_leaf(self, expr: prim.ExpressionNode) -> Mapping[BindingName, - set[tuple[Expression, ...]]]: + set[tuple[prim.Variable, ...]]]: return {} def map_constant(self, expr: object) -> Mapping[BindingName, - set[tuple[Expression, ...]]]: + set[tuple[prim.Variable, ...]]]: return {} @@ -282,22 +287,25 @@ def map_index_lambda(self, expr: IndexLambda) -> None: out_shape = expr.shape assert len(out_shape) == expr.ndim - for vname, ind_tuple in index_expr_used.items(): - for axis_ind in range(len(ind_tuple)): - var_ind_name = ind_tuple[axis_ind] - if isinstance(var_ind_name, prim.Variable): + for vname, set_of_ind_tuple in index_expr_used.items(): + for ind_tuple in set_of_ind_tuple: + for axis_ind in range(len(ind_tuple)): + var_ind_name = ind_tuple[axis_ind] if IDX_LAMBDA_JUST_REDUCTIONS.fullmatch(var_ind_name.name): # Reduction axis. We can ignore it. pass elif var_ind_name.name[:3] == "_in": # Variable name axis. pass + elif var_ind_name.name == "IGNORE": + # This is not directly represented in output axes. Ignore. + pass elif IDX_LAMBDA_INAME.fullmatch(var_ind_name.name): # matched with an iname. 
inum = int(var_ind_name.name[1:]) - val = (self.get_var_for_axis(expr.bindings[vname], axis_ind), - self.get_var_for_axis(expr, inum)) - self.equations.append(val) + lhs: str = self.get_var_for_axis(expr.bindings[vname], axis_ind) + rhs: str = self.get_var_for_axis(expr, inum) + self.record_equation(lhs, rhs) else: raise ValueError(f"Unknown index name used in {vname}") return From 6e73ed6c0078b5862b6b2628c8c946dd3aebc85e Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 19 Dec 2024 23:00:51 -0600 Subject: [PATCH 20/28] Add a mapper for applying the updates in the case of a reduction operator working on an IndexLambda. --- pytato/transform/metadata.py | 65 ++++++++++++++++++++++++++++-------- 1 file changed, 51 insertions(+), 14 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index e99a21b6d..4c6141009 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -79,6 +79,7 @@ IDX_LAMBDA_INAME, IDX_LAMBDA_JUST_REDUCTIONS, CombineMapper, + IdentityMapper as ScalarIdentityMapper, ) from pytato.transform import ArrayOrNames, CopyMapper, Mapper from pytato.utils import are_shape_components_equal, are_shapes_equal @@ -99,6 +100,8 @@ BindingName: TypeAlias = str P = ParamSpec("P") +# {{{ IndexExpressionsUsedInIndexLambda + class IndexExpressionsUsedInIndexLambda(CombineMapper[Mapping[BindingName, set[tuple[prim.Variable, ...]]], @@ -126,11 +129,11 @@ def map_subscript(self, expr: prim.Subscript) -> Mapping[BindingName, if isinstance(expr.aggregate, prim.Variable): name: BindingName = expr.aggregate.name - index = (val if isinstance(val, prim.Variable) + index = tuple(val if isinstance(val, prim.Variable) else prim.Variable(name="IGNORE") for val in expr.index_tuple) base: Mapping[BindingName, set[tuple[prim.Variable, ...]]] = {name: - set([tuple(index)])} + {index}} return self.combine([base, self.rec(expr.index)]) return {} @@ -142,6 +145,43 @@ def map_algebraic_leaf(self, expr: prim.ExpressionNode) -> Mapping[BindingName, def map_constant(self, expr: object) -> Mapping[BindingName, set[tuple[prim.Variable, ...]]]: return {} +# }}} + +# {{{ Tag Reduction expressions + + +class TagReductionAxesMapper(ScalarIdentityMapper[[]]): + + def __init__(self, + axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]], + array_expr: IndexLambda): + self.axis_to_tags = axis_to_tags + self.array_expr = array_expr + + def map_subscript(self, + expr: prim.Subscript, + #*args: P.args, + #**kwargs: P.kwargs) -> prim.Expression: + ) -> prim.Expression: + + if isinstance(expr.aggregate, prim.Variable): + name: BindingName = expr.aggregate.name + + for iaxis, val in enumerate(expr.index_tuple): + if isinstance(val, prim.Variable): + if val.name in self.array_expr.var_to_reduction_descr.keys() and \ + name in self.array_expr.bindings.keys(): + # We matched the reduction axis. + # Now we need to add the tag to the original expression. + my_key = (self.array_expr.bindings[name], iaxis) + self.array_expr = self.array_expr.with_tagged_reduction( + name, self.axis_to_tags.get(my_key, []) + ) + + assert isinstance(self.array_expr, IndexLambda) + #return super().map_subscript(expr, *args, **kwargs) + return super().map_subscript(expr) +# }}} # {{{ AxesTagsEquationCollector @@ -645,22 +685,19 @@ def rec(self, expr: ArrayOrNames) -> Any: self.axis_to_tags.get((arg, iaxis), []) ) - if isinstance(expr, IndexLambda): - if expr.var_to_reduction_descr: + if isinstance(expr_copy, IndexLambda): + if expr_copy.var_to_reduction_descr: # This is a reduction operation. 
# We need to find the axes that are reduced over # and update the tag/tag them appropriately. - for iaxis in range(len(expr.expr.inner_expr.index_tuple)): - name = expr.expr.inner_expr.index_tuple[iaxis].name - if name in expr.var_to_reduction_descr.keys(): - assert len(list(expr.bindings.keys())) == 1 - my_arr: Array = next(iter(expr.bindings.values())) + mymapper: TagReductionAxesMapper = \ + TagReductionAxesMapper( + self.axis_to_tags, + expr_copy + ) + mymapper(expr_copy.expr) # Tag the axes + expr_copy = mymapper.array_expr # Recover it. - assert isinstance(expr_copy, IndexLambda) - expr_copy = expr_copy.with_tagged_reduction( - name, - self.axis_to_tags.get((my_arr, iaxis), []) - ) # }}} self._cache[key] = expr_copy From 0336c0054f226f65f3afe4d8b4a678832f5e1753 Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 19 Dec 2024 23:15:46 -0600 Subject: [PATCH 21/28] Fix the ruff comments. --- test/test_pytato.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_pytato.py b/test/test_pytato.py index 70dfc3fee..a38e4961f 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1388,6 +1388,7 @@ def test_unify_axes_tags_with_unbroadcastable_expressions(): assert (term.axes[0].tags_of_type(TestlibTag) == frozenset([BazTag()])) assert (term.axes[1].tags_of_type(TestlibTag) == frozenset([QuuxTag()])) + def test_ignoring_axes_during_propagation(): from pytools.tag import UniqueTag From 5242f3fe9d36fbcb5ec1d5a7247ce8ff94f1863a Mon Sep 17 00:00:00 2001 From: Nick Date: Fri, 20 Dec 2024 00:02:29 -0600 Subject: [PATCH 22/28] Update the test code to use the correct name of the argument of its internal function. --- test/test_codegen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_codegen.py b/test/test_codegen.py index b2827130a..6784be948 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -1579,7 +1579,7 @@ def kernel(usr_np, pt_data_9): def get_np_input_args(): return { - "_pt_data_9": np.ones((2, 2)), + "pt_data_9": np.ones((2, 2)), } np_inputs = get_np_input_args() From e7e750a6230cdea0d7a787d57115d4e6105cd149 Mon Sep 17 00:00:00 2001 From: Nick Koskelo Date: Wed, 8 Jan 2025 19:49:11 +0000 Subject: [PATCH 23/28] Add a test case for the pattern match on binding names. Use the reduce function. --- pytato/transform/metadata.py | 78 ++++++++++++++++++++---------------- test/test_pytato.py | 14 +++++++ 2 files changed, 58 insertions(+), 34 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index e4cffa2a2..4eb02b38a 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -38,9 +38,9 @@ THE SOFTWARE. 
""" - import logging -from collections.abc import Mapping +import re +from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, @@ -89,7 +89,7 @@ if TYPE_CHECKING: - from collections.abc import Collection, Iterable + from collections.abc import Collection, Iterable, Mapping from pytato.function import NamedCallResult from pytato.loopy import LoopyCall @@ -100,49 +100,65 @@ BindingName: TypeAlias = str P = ParamSpec("P") -# {{{ IndexExpressionsUsedInIndexLambda +IGNORE_VARIABLE_STR: str = "IGNORE" +BINDING_NAME_RESERVED_PATTERN = re.compile(r"^(_in?(0|([1-9][0-9]*)))$") + +# {{{ BindingSubscriptsUsedInIndexLambda -class IndexExpressionsUsedInIndexLambda(CombineMapper[Mapping[BindingName, +class BindingSubscriptsUsedInIndexLambda(CombineMapper[dict[BindingName, set[tuple[prim.Variable, ...]]], []]): """ - Determine which axes are used in the scalar expressionand which ones just - flow through the expression. + Return all the subscript expressions used by a variable specified by BindingName. + + Ex: + _in1[_0,_1] would result in an dictionary entry {"_in1": ("_0", "_1")}. + + In the case that a subscript expression is not a variable, like in + + Ex: + _in1[_0, 0] + + that subscript will be replaced with a `Variable` with the name IGNORE. + + So the second example would result in an dictionary entry + {"_in1": ("_0","IGNORE")}. """ def combine(self, - values: Iterable[Mapping[BindingName, + values: Iterable[dict[BindingName, set[tuple[prim.Variable, ...]]]]) \ - -> Mapping[BindingName, set[tuple[prim.Variable, ...]]]: + -> dict[BindingName, set[tuple[prim.Variable, ...]]]: out: dict[BindingName, set[tuple[prim.Variable, ...]]] = {} - for val in values: - out.update(val) - return out + from functools import reduce + return reduce(lambda x, y: x | y, values, out) - def map_subscript(self, expr: prim.Subscript) -> Mapping[BindingName, + def map_subscript(self, expr: prim.Subscript) -> dict[BindingName, set[tuple[prim.Variable, ...]]]: """ - Record the indexing usage for the variable if we are tracking - the specific variable. + Record the indexing expression if the Subscript expression has a prim.Variable + as its aggregate. This will record an ignorable variable for each part of the + indexing expression that is not already a prim.Variable. 
""" if isinstance(expr.aggregate, prim.Variable): name: BindingName = expr.aggregate.name index = tuple(val if isinstance(val, prim.Variable) - else prim.Variable(name="IGNORE") + else prim.Variable(name=IGNORE_VARIABLE_STR) for val in expr.index_tuple) - base: Mapping[BindingName, set[tuple[prim.Variable, ...]]] = {name: + base: dict[BindingName, set[tuple[prim.Variable, ...]]] = {name: {index}} - return self.combine([base, self.rec(expr.index)]) + assert base + return base return {} - def map_algebraic_leaf(self, expr: prim.ExpressionNode) -> Mapping[BindingName, + def map_algebraic_leaf(self, expr: prim.ExpressionNode) -> dict[BindingName, set[tuple[prim.Variable, ...]]]: return {} - def map_constant(self, expr: object) -> Mapping[BindingName, + def map_constant(self, expr: object) -> dict[BindingName, set[tuple[prim.Variable, ...]]]: return {} # }}} @@ -150,18 +166,14 @@ def map_constant(self, expr: object) -> Mapping[BindingName, # {{{ Tag Reduction expressions +@dataclass(init=True, repr=True) class TagReductionAxesMapper(ScalarIdentityMapper[[]]): - def __init__(self, - axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]], - array_expr: IndexLambda): - self.axis_to_tags = axis_to_tags - self.array_expr = array_expr + axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] + array_expr: IndexLambda def map_subscript(self, expr: prim.Subscript, - #*args: P.args, - #**kwargs: P.kwargs) -> prim.Expression: ) -> prim.Expression: if isinstance(expr.aggregate, prim.Variable): @@ -179,7 +191,7 @@ def map_subscript(self, ) assert isinstance(self.array_expr, IndexLambda) - #return super().map_subscript(expr, *args, **kwargs) + return super().map_subscript(expr) # }}} @@ -321,11 +333,9 @@ def map_index_lambda(self, expr: IndexLambda) -> None: for bnd in expr.bindings.values(): self.rec(bnd) - index_expr_used = IndexExpressionsUsedInIndexLambda()(expr.expr) + index_expr_used = BindingSubscriptsUsedInIndexLambda()(expr.expr) - if __debug__: - out_shape = expr.shape - assert len(out_shape) == expr.ndim + assert len(expr.shape) == expr.ndim for vname, set_of_ind_tuple in index_expr_used.items(): for ind_tuple in set_of_ind_tuple: @@ -334,10 +344,10 @@ def map_index_lambda(self, expr: IndexLambda) -> None: if IDX_LAMBDA_JUST_REDUCTIONS.fullmatch(var_ind_name.name): # Reduction axis. We can ignore it. pass - elif var_ind_name.name[:3] == "_in": + elif BINDING_NAME_RESERVED_PATTERN.fullmatch(var_ind_name.name): # Variable name axis. pass - elif var_ind_name.name == "IGNORE": + elif var_ind_name.name == IGNORE_VARIABLE_STR: # This is not directly represented in output axes. Ignore. 
pass elif IDX_LAMBDA_INAME.fullmatch(var_ind_name.name): diff --git a/test/test_pytato.py b/test/test_pytato.py index a38e4961f..343c412e3 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1213,6 +1213,20 @@ def test_lower_to_index_lambda(): assert isinstance(binding, Reshape) +def test_reserved_binding_name_patterns(): + from pytato.transform.metadata import BINDING_NAME_RESERVED_PATTERN + + fail_strings = ["_r0", "_r000", "_r01", "_00", "_r101", "_1", "_0", "_101", + "_in", "_in00", "1_in", "_in01"] + pass_strings = ["_in0", "_in1", "_in554", "_in10"] + + for test_str in fail_strings: + assert not BINDING_NAME_RESERVED_PATTERN.fullmatch(test_str) + + for test_str in pass_strings: + assert BINDING_NAME_RESERVED_PATTERN.fullmatch(test_str) + + def test_reserved_scalar_iname_patterns(): from pytato.scalar_expr import ( IDX_LAMBDA_INAME, From 23ffe34bcccb206ea575bdee06b97785bb2c350a Mon Sep 17 00:00:00 2001 From: Nick Koskelo Date: Tue, 14 Jan 2025 00:24:03 +0000 Subject: [PATCH 24/28] Let's get the reduction descriptors at the start when we are recording the equations and propagate that information through. --- pytato/transform/metadata.py | 111 +++++++++++++++++++---------------- test/test_pytato.py | 42 +++++++++++++ 2 files changed, 101 insertions(+), 52 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 4eb02b38a..3f5bd770b 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -77,7 +77,6 @@ from pytato.function import NamedCallResult from pytato.scalar_expr import ( IDX_LAMBDA_INAME, - IDX_LAMBDA_JUST_REDUCTIONS, CombineMapper, IdentityMapper as ScalarIdentityMapper, ) @@ -100,14 +99,14 @@ BindingName: TypeAlias = str P = ParamSpec("P") -IGNORE_VARIABLE_STR: str = "IGNORE" BINDING_NAME_RESERVED_PATTERN = re.compile(r"^(_in?(0|([1-9][0-9]*)))$") -# {{{ BindingSubscriptsUsedInIndexLambda +# {{{ BindingSubscriptsCollector -class BindingSubscriptsUsedInIndexLambda(CombineMapper[dict[BindingName, - set[tuple[prim.Variable, ...]]], + +class BindingSubscriptsCollector(CombineMapper[dict[BindingName, + set[tuple[prim.Expression, ...]]], []]): """ Return all the subscript expressions used by a variable specified by BindingName. @@ -127,39 +126,37 @@ class BindingSubscriptsUsedInIndexLambda(CombineMapper[dict[BindingName, """ def combine(self, values: Iterable[dict[BindingName, - set[tuple[prim.Variable, ...]]]]) \ - -> dict[BindingName, set[tuple[prim.Variable, ...]]]: - out: dict[BindingName, set[tuple[prim.Variable, ...]]] = {} + set[tuple[prim.Expression, ...]]]]) \ + -> dict[BindingName, set[tuple[prim.Expression, ...]]]: + out: dict[BindingName, set[tuple[prim.Expression, ...]]] = {} from functools import reduce return reduce(lambda x, y: x | y, values, out) def map_subscript(self, expr: prim.Subscript) -> dict[BindingName, - set[tuple[prim.Variable, ...]]]: + set[tuple[prim.Expression, ...]]]: """ - Record the indexing expression if the Subscript expression has a prim.Variable + Record the indexing expression if the Subscript expression has a prim.Expression as its aggregate. This will record an ignorable variable for each part of the - indexing expression that is not already a prim.Variable. + indexing expression that is not already a prim.Expression. 
""" if isinstance(expr.aggregate, prim.Variable): name: BindingName = expr.aggregate.name - index = tuple(val if isinstance(val, prim.Variable) - else prim.Variable(name=IGNORE_VARIABLE_STR) - for val in expr.index_tuple) - base: dict[BindingName, set[tuple[prim.Variable, ...]]] = {name: - {index}} - assert base + base: dict[BindingName, + set[tuple[prim.Expression, ...]]] = {name: {expr.index_tuple}} + index = self.rec(expr.index) + breakpoint() return base return {} - def map_algebraic_leaf(self, expr: prim.ExpressionNode) -> dict[BindingName, - set[tuple[prim.Variable, ...]]]: + def map_algebraic_leaf(self, expr: prim.Expression) -> dict[BindingName, + set[tuple[prim.Expression, ...]]]: return {} def map_constant(self, expr: object) -> dict[BindingName, - set[tuple[prim.Variable, ...]]]: + set[tuple[prim.Expression, ...]]]: return {} # }}} @@ -171,6 +168,7 @@ class TagReductionAxesMapper(ScalarIdentityMapper[[]]): axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] array_expr: IndexLambda + orig_array_expr: IndexLambda def map_subscript(self, expr: prim.Subscript, @@ -185,9 +183,11 @@ def map_subscript(self, name in self.array_expr.bindings.keys(): # We matched the reduction axis. # Now we need to add the tag to the original expression. - my_key = (self.array_expr.bindings[name], iaxis) + redn_name = val.name + breakpoint() + my_key = (self.orig_array_expr.bindings[name], iaxis) self.array_expr = self.array_expr.with_tagged_reduction( - name, self.axis_to_tags.get(my_key, []) + redn_name, self.axis_to_tags.get(my_key, []) ) assert isinstance(self.array_expr, IndexLambda) @@ -253,16 +253,17 @@ def __init__(self, tag_t: type[Tag]) -> None: # axis_to_var: mapping from (array, iaxis) to the variable to be # used for unification. - self.axis_to_var: bidict[tuple[Array, int], str] = bidict() + self.axis_to_var: bidict[tuple[Array, int | str], str] = bidict() self.known_tag_to_var: dict[Tag, str] = {} self.equations: list[tuple[str, str]] = [] + self.reduction_equations: list[tuple[str, str]] = [] # }}} # {{{ unification helpers - def get_var_for_axis(self, ary: Array, iaxis: int) -> str: + def get_var_for_axis(self, ary: Array, iaxis: int | str) -> str: key = (ary, iaxis) try: @@ -287,6 +288,13 @@ def record_equation(self, lhs: str, rhs: str) -> None: :attr:`equations`. """ self.equations.append((lhs, rhs)) + + def record_reduction_equation(self, lhs: str, rhs: str) -> None: + r""" + Adds the equation :math:`\{\text{lhs}\doteq\text{rhs}}` to + :attr:`equations`. + """ + self.reduction_equations.append((lhs, rhs)) def record_equations_from_axes_tags(self, ary: Array) -> None: """ @@ -324,40 +332,36 @@ def _map_input_base(self, expr: InputArgumentBase) -> None: def map_index_lambda(self, expr: IndexLambda) -> None: """ - The propagation semantics for a :class:`~pytato.IndexLambda` are - implemented only for operations that can be raised to a - :class:`~pytato.raising.HighLevelOp`. In such cases, an equality - equation is recorded for every non-broadcasted axis of an operand and - its corresponding axis of *expr*. + Equality conditions are added between an axis of the operands which is indexed + by a :class:`~pymbolic.Variable` which has a name that follows the reserved + iname format, "_[0-9]+", and the axis of the output specified by the iname. 
""" for bnd in expr.bindings.values(): self.rec(bnd) - index_expr_used = BindingSubscriptsUsedInIndexLambda()(expr.expr) + index_expr_used = BindingSubscriptsCollector()(expr.expr) assert len(expr.shape) == expr.ndim for vname, set_of_ind_tuple in index_expr_used.items(): for ind_tuple in set_of_ind_tuple: - for axis_ind in range(len(ind_tuple)): - var_ind_name = ind_tuple[axis_ind] - if IDX_LAMBDA_JUST_REDUCTIONS.fullmatch(var_ind_name.name): - # Reduction axis. We can ignore it. - pass - elif BINDING_NAME_RESERVED_PATTERN.fullmatch(var_ind_name.name): - # Variable name axis. - pass - elif var_ind_name.name == IGNORE_VARIABLE_STR: - # This is not directly represented in output axes. Ignore. - pass - elif IDX_LAMBDA_INAME.fullmatch(var_ind_name.name): - # matched with an iname. - inum = int(var_ind_name.name[1:]) - lhs: str = self.get_var_for_axis(expr.bindings[vname], axis_ind) - rhs: str = self.get_var_for_axis(expr, inum) - self.record_equation(lhs, rhs) - else: - raise ValueError(f"Unknown index name used in {vname}") + for axis_ind, var_ind_name in enumerate(ind_tuple): + if isinstance(var_ind_name, prim.Variable): + if IDX_LAMBDA_INAME.fullmatch(var_ind_name.name): + # matched with an iname. + inum = int(var_ind_name.name[1:]) + lhs: str = self.get_var_for_axis(expr.bindings[vname], + axis_ind) + rhs: str = self.get_var_for_axis(expr, inum) + self.record_equation(lhs, rhs) + elif var_ind_name.name in expr.var_to_reduction_descr.keys():\ + # matched with a reduction iname. + breakpoint() + lhs: str = self.get_var_for_axis(expr.bindings[vname], + axis_ind) + rhs: str = self.get_var_for_axis(expr, var_ind_name.name) + self.record_reduction_equation(lhs, rhs) + return def map_stack(self, expr: Stack) -> None: @@ -695,7 +699,8 @@ def rec(self, expr: ArrayOrNames) -> Any: self.axis_to_tags.get((arg, iaxis), []) ) - if isinstance(expr_copy, IndexLambda): + if isinstance(expr, IndexLambda): + assert isinstance(expr_copy, IndexLambda) if expr_copy.var_to_reduction_descr: # This is a reduction operation. # We need to find the axes that are reduced over @@ -703,7 +708,8 @@ def rec(self, expr: ArrayOrNames) -> Any: mymapper: TagReductionAxesMapper = \ TagReductionAxesMapper( self.axis_to_tags, - expr_copy + expr_copy, + expr ) mymapper(expr_copy.expr) # Tag the axes expr_copy = mymapper.array_expr # Recover it. @@ -775,7 +781,7 @@ def unify_axes_tags( axis_to_solved_tags: dict[tuple[Array, int], set[Tag]] = {} propagation_graph = undirected_graph_from_edges( - equations_collector.equations + equations_collector.equations + equations_collector.reduction_equations ) ignored_vars = set({ @@ -797,6 +803,7 @@ def unify_axes_tags( set() ).add(tag) + breakpoint() return AxisTagAttacher(axis_to_solved_tags, tag_corresponding_redn_descr=unify_redn_descrs, )(expr) diff --git a/test/test_pytato.py b/test/test_pytato.py index 343c412e3..a287f416a 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1403,6 +1403,48 @@ def test_unify_axes_tags_with_unbroadcastable_expressions(): assert (term.axes[1].tags_of_type(TestlibTag) == frozenset([QuuxTag()])) + # Side-by-Side reduction. 
+ + # a[_0] + sum(_r0, b[_0, _r0] + sum(_r1, b[_0,_r1])) + a = pt.make_placeholder("a", (512)) + b = pt.make_placeholder("b", (512, 10)) + c = pt.make_placeholder("c", (512, 10)) + c = c.with_tagged_axis(1, QuuxTag()) + b = b.with_tagged_axis(1, FooTag()) + a = a.with_tagged_axis(0, BazTag()) + + x = prim.Subscript(prim.Variable("_in0"), (prim.Variable("_0"))) + y = prim.Subscript(prim.Variable("_in1"), + (prim.Variable("_0"), prim.Variable("_r0"))) + z = prim.Subscript(prim.Variable("_in2"), (prim.Variable("_0"), + prim.Variable("_r1"))) + + w = pt.IndexLambda(expr=pt.scalar_expr.Reduce(prim.Sum((x,y,z)), + pt.reductions.SumReductionOperation, + immutabledict({"_r0": (0,10), + "_r1": (0,10)})), + bindings=immutabledict({"_in0": a, "_in1": b, "_in2": c}), + shape=(512,), tags=pt.array._get_default_tags(), + axes=pt.array._get_default_axes(1), + dtype=float, + var_to_reduction_descr=immutabledict({"_r0": pt.array.ReductionDescriptor( + frozenset([]) + ), + "_r1": pt.array.ReductionDescriptor( + frozenset([]) + )})) + + w_unified = pt.unify_axes_tags(w) + + assert w_unified.var_to_reduction_descr["_r0"].tags_of_type(TestlibTag) == frozenset([FooTag()]) + assert w_unified.var_to_reduction_descr["_r1"].tags_of_type(TestlibTag) == frozenset([QuuxTag()]) + + assert w_unified.axes[0].tags_of_type(TestlibTag) == frozenset([BazTag()]) + for key in ["_in" + str(i) for i in range(2)]: + assert w_unified.bindings[key].axes[0].tags_of_type(TestlibTag) == frozenset([BazTag()]) + + + def test_ignoring_axes_during_propagation(): from pytools.tag import UniqueTag From 9ac0a7d03f62e1d71d67f52f79480ba53ba0d52a Mon Sep 17 00:00:00 2001 From: Nick Koskelo Date: Tue, 14 Jan 2025 22:46:10 +0000 Subject: [PATCH 25/28] Update the record equations keys to be [Array, int | str] so that we can store the reduction string as well. --- pytato/transform/metadata.py | 110 +++++++------------------- test/test_pytato.py | 144 +++++++++++++++++++++++++---------- 2 files changed, 128 insertions(+), 126 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 3f5bd770b..3fa275708 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -40,7 +40,6 @@ import logging import re -from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, @@ -78,7 +77,6 @@ from pytato.scalar_expr import ( IDX_LAMBDA_INAME, CombineMapper, - IdentityMapper as ScalarIdentityMapper, ) from pytato.transform import ArrayOrNames, CopyMapper, Mapper from pytato.utils import are_shape_components_equal, are_shapes_equal @@ -110,19 +108,8 @@ class BindingSubscriptsCollector(CombineMapper[dict[BindingName, []]): """ Return all the subscript expressions used by a variable specified by BindingName. - Ex: _in1[_0,_1] would result in an dictionary entry {"_in1": ("_0", "_1")}. - - In the case that a subscript expression is not a variable, like in - - Ex: - _in1[_0, 0] - - that subscript will be replaced with a `Variable` with the name IGNORE. - - So the second example would result in an dictionary entry - {"_in1": ("_0","IGNORE")}. """ def combine(self, values: Iterable[dict[BindingName, @@ -135,19 +122,12 @@ def combine(self, def map_subscript(self, expr: prim.Subscript) -> dict[BindingName, set[tuple[prim.Expression, ...]]]: """ - Record the indexing expression if the Subscript expression has a prim.Expression - as its aggregate. This will record an ignorable variable for each part of the - indexing expression that is not already a prim.Expression. 
+ Record the indexing expression if the Subscript expression has a prim.Variable + as its aggregate. """ if isinstance(expr.aggregate, prim.Variable): - name: BindingName = expr.aggregate.name - - base: dict[BindingName, - set[tuple[prim.Expression, ...]]] = {name: {expr.index_tuple}} - index = self.rec(expr.index) - breakpoint() - return base + return {expr.aggregate.name: {expr.index_tuple}} return {} def map_algebraic_leaf(self, expr: prim.Expression) -> dict[BindingName, @@ -160,44 +140,9 @@ def map_constant(self, expr: object) -> dict[BindingName, return {} # }}} -# {{{ Tag Reduction expressions - - -@dataclass(init=True, repr=True) -class TagReductionAxesMapper(ScalarIdentityMapper[[]]): - - axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] - array_expr: IndexLambda - orig_array_expr: IndexLambda - - def map_subscript(self, - expr: prim.Subscript, - ) -> prim.Expression: - - if isinstance(expr.aggregate, prim.Variable): - name: BindingName = expr.aggregate.name - - for iaxis, val in enumerate(expr.index_tuple): - if isinstance(val, prim.Variable): - if val.name in self.array_expr.var_to_reduction_descr.keys() and \ - name in self.array_expr.bindings.keys(): - # We matched the reduction axis. - # Now we need to add the tag to the original expression. - redn_name = val.name - breakpoint() - my_key = (self.orig_array_expr.bindings[name], iaxis) - self.array_expr = self.array_expr.with_tagged_reduction( - redn_name, self.axis_to_tags.get(my_key, []) - ) - - assert isinstance(self.array_expr, IndexLambda) - - return super().map_subscript(expr) -# }}} - - # {{{ AxesTagsEquationCollector + class AxesTagsEquationCollector(Mapper[None, []]): r""" Records equations arising from operand/output axes equivalence for an array @@ -258,7 +203,6 @@ def __init__(self, tag_t: type[Tag]) -> None: self.equations: list[tuple[str, str]] = [] - self.reduction_equations: list[tuple[str, str]] = [] # }}} # {{{ unification helpers @@ -288,13 +232,6 @@ def record_equation(self, lhs: str, rhs: str) -> None: :attr:`equations`. """ self.equations.append((lhs, rhs)) - - def record_reduction_equation(self, lhs: str, rhs: str) -> None: - r""" - Adds the equation :math:`\{\text{lhs}\doteq\text{rhs}}` to - :attr:`equations`. - """ - self.reduction_equations.append((lhs, rhs)) def record_equations_from_axes_tags(self, ary: Array) -> None: """ @@ -347,20 +284,17 @@ def map_index_lambda(self, expr: IndexLambda) -> None: for ind_tuple in set_of_ind_tuple: for axis_ind, var_ind_name in enumerate(ind_tuple): if isinstance(var_ind_name, prim.Variable): + lhs: str = self.get_var_for_axis(expr.bindings[vname], + axis_ind) if IDX_LAMBDA_INAME.fullmatch(var_ind_name.name): # matched with an iname. inum = int(var_ind_name.name[1:]) - lhs: str = self.get_var_for_axis(expr.bindings[vname], - axis_ind) rhs: str = self.get_var_for_axis(expr, inum) self.record_equation(lhs, rhs) - elif var_ind_name.name in expr.var_to_reduction_descr.keys():\ + elif var_ind_name.name in expr.var_to_reduction_descr.keys(): # matched with a reduction iname. - breakpoint() - lhs: str = self.get_var_for_axis(expr.bindings[vname], - axis_ind) - rhs: str = self.get_var_for_axis(expr, var_ind_name.name) - self.record_reduction_equation(lhs, rhs) + rhs = self.get_var_for_axis(expr, var_ind_name.name) + self.record_equation(lhs, rhs) return @@ -661,10 +595,11 @@ class AxisTagAttacher(CopyMapper): A mapper that tags the axes in a DAG as prescribed by *axis_to_tags*. 
""" def __init__(self, - axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]], + axis_to_tags: Mapping[tuple[Array, int | str], Collection[Tag]], tag_corresponding_redn_descr: bool): super().__init__() - self.axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] = axis_to_tags + self.axis_to_tags: Mapping[tuple[Array, int | str], + Collection[Tag]] = axis_to_tags self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr def rec(self, expr: ArrayOrNames) -> Any: @@ -705,6 +640,13 @@ def rec(self, expr: ArrayOrNames) -> Any: # This is a reduction operation. # We need to find the axes that are reduced over # and update the tag/tag them appropriately. + for redn_var in expr.var_to_reduction_descr.keys(): + expr_copy = expr_copy.with_tagged_reduction( + redn_var, + self.axis_to_tags.get((expr, redn_var), []) + ) + + """ mymapper: TagReductionAxesMapper = \ TagReductionAxesMapper( self.axis_to_tags, @@ -713,6 +655,7 @@ def rec(self, expr: ArrayOrNames) -> Any: ) mymapper(expr_copy.expr) # Tag the axes expr_copy = mymapper.array_expr # Recover it. + """ # }}} @@ -778,10 +721,10 @@ def unify_axes_tags( ) known_tag_vars = frozenset(equations_collector.known_tag_to_var.values()) - axis_to_solved_tags: dict[tuple[Array, int], set[Tag]] = {} + axis_to_solved_tags: dict[tuple[Array, int | str], set[Tag]] = {} propagation_graph = undirected_graph_from_edges( - equations_collector.equations + equations_collector.reduction_equations + equations_collector.equations ) ignored_vars = set({ @@ -789,10 +732,10 @@ def unify_axes_tags( if isinstance(tag, AxisIgnoredForPropagationTag) }) - ignored_vars.update({ - ax_var for (ary, ax), ax_var in equations_collector.axis_to_var.items() - if ary.axes[ax].tags_of_type(AxisIgnoredForPropagationTag) - }) + for (ary, ax), ax_var in equations_collector.axis_to_var.items(): + if isinstance(ax, int): + if ary.axes[ax].tags_of_type(AxisIgnoredForPropagationTag): + ignored_vars.update({ax_var}) for tag, var in equations_collector.known_tag_to_var.items(): reachable_nodes = get_reachable_nodes(propagation_graph, var, @@ -803,7 +746,6 @@ def unify_axes_tags( set() ).add(tag) - breakpoint() return AxisTagAttacher(axis_to_solved_tags, tag_corresponding_redn_descr=unify_redn_descrs, )(expr) diff --git a/test/test_pytato.py b/test/test_pytato.py index a287f416a..498a4cd17 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1365,6 +1365,108 @@ def test_unify_axes_tags(): # }}} + # {{ Reduction Operations with IndexLambda + # {{{ Reduce on outside of scalar expression + + from immutabledict import immutabledict + + import pymbolic.primitives as prim + + def setup_test(): + a = pt.make_placeholder("a", (512)) + b = pt.make_placeholder("b", (512, 10)) + c = pt.make_placeholder("c", (512, 10)) + c = c.with_tagged_axis(1, QuuxTag()) + b = b.with_tagged_axis(1, FooTag()) + a = a.with_tagged_axis(0, BazTag()) + + x = prim.Subscript(prim.Variable("_in0"), (prim.Variable("_0"))) + y = prim.Subscript(prim.Variable("_in1"), + (prim.Variable("_0"), prim.Variable("_r0"))) + z = prim.Subscript(prim.Variable("_in2"), (prim.Variable("_0"), + prim.Variable("_r1"))) + + return a, b, c, x, y, z + + def assert_tags_were_propagated_appropriately(arr): + assert arr.var_to_reduction_descr["_r0"].tags_of_type(TestlibTag) == \ + frozenset([FooTag()]) + assert arr.var_to_reduction_descr["_r1"].tags_of_type(TestlibTag) == \ + frozenset([QuuxTag()]) + + assert arr.axes[0].tags_of_type(TestlibTag) == frozenset([BazTag()]) + for key in ["_in" + str(i) for i in range(2)]: + 
assert arr.bindings[key].axes[0].tags_of_type(TestlibTag) == \ + frozenset([BazTag()]) + + def get_def_reduction_descrs(): + return immutabledict({"_r0": pt.array.ReductionDescriptor(frozenset([])), + "_r1": pt.array.ReductionDescriptor(frozenset([])) + }) + + a, b, c, x, y, z = setup_test() + # sum((_r0, _r1), a[_0] + b[_0, _r0] + b[_0,_r1])) + w = pt.IndexLambda(expr=pt.scalar_expr.Reduce(prim.Sum((x, y, z)), + pt.reductions.SumReductionOperation, + immutabledict({"_r0": (0, 10), + "_r1": (0, 10)})), + bindings=immutabledict({"_in0": a, "_in1": b, "_in2": c}), + shape=(512,), tags=pt.array._get_default_tags(), + axes=pt.array._get_default_axes(1), + dtype=float, + var_to_reduction_descr=get_def_reduction_descrs()) + + w_unified = pt.unify_axes_tags(w) + + assert_tags_were_propagated_appropriately(w_unified) + + # }}} Reduction on the outside of the scalar expression. + + # {{{ Side-by-Side reduction. + + a, b, c, x, y, z = setup_test() + + # a[_0] + sum(_r0, b[_0, _r0]) + sum(_r1, b[_0,_r1]) + + w = pt.IndexLambda(expr=prim.Sum((x, pt.scalar_expr.Reduce(y, + pt.reductions.SumReductionOperation, + immutabledict({"_r0": (0, 10)})), + pt.scalar_expr.Reduce(z, + pt.reductions.SumReductionOperation, + immutabledict({"_r1": (0, 10)})))), + bindings=immutabledict({"_in0": a, "_in1": b, "_in2": c}), + shape=(512,), tags=pt.array._get_default_tags(), + axes=pt.array._get_default_axes(1), + dtype=float, + var_to_reduction_descr=get_def_reduction_descrs()) + + w_unified = pt.unify_axes_tags(w) + assert_tags_were_propagated_appropriately(w_unified) + + # }}} + + # {{{ Nested Reductions. + # a[_0] + sum(_r0, b[_0, _r0] + sum(_r1, b[_0,_r1])) + a, b, c, x, y, z = setup_test() + + w = pt.IndexLambda(expr=prim.Sum((x, pt.scalar_expr.Reduce(prim.Sum((y, + pt.scalar_expr.Reduce(z, + pt.reductions.SumReductionOperation, + immutabledict({"_r1": (0, 10)})))), + pt.reductions.SumReductionOperation, + immutabledict({"_r0": (0, 10)})))), + bindings=immutabledict({"_in0": a, "_in1": b, "_in2": c}), + shape=(512,), tags=pt.array._get_default_tags(), + axes=pt.array._get_default_axes(1), + dtype=float, + var_to_reduction_descr=get_def_reduction_descrs()) + + w_unified = pt.unify_axes_tags(w) + assert_tags_were_propagated_appropriately(w_unified) + + # }}} + # }} + def test_unify_axes_tags_with_unbroadcastable_expressions(): @@ -1403,48 +1505,6 @@ def test_unify_axes_tags_with_unbroadcastable_expressions(): assert (term.axes[1].tags_of_type(TestlibTag) == frozenset([QuuxTag()])) - # Side-by-Side reduction. 
- - # a[_0] + sum(_r0, b[_0, _r0] + sum(_r1, b[_0,_r1])) - a = pt.make_placeholder("a", (512)) - b = pt.make_placeholder("b", (512, 10)) - c = pt.make_placeholder("c", (512, 10)) - c = c.with_tagged_axis(1, QuuxTag()) - b = b.with_tagged_axis(1, FooTag()) - a = a.with_tagged_axis(0, BazTag()) - - x = prim.Subscript(prim.Variable("_in0"), (prim.Variable("_0"))) - y = prim.Subscript(prim.Variable("_in1"), - (prim.Variable("_0"), prim.Variable("_r0"))) - z = prim.Subscript(prim.Variable("_in2"), (prim.Variable("_0"), - prim.Variable("_r1"))) - - w = pt.IndexLambda(expr=pt.scalar_expr.Reduce(prim.Sum((x,y,z)), - pt.reductions.SumReductionOperation, - immutabledict({"_r0": (0,10), - "_r1": (0,10)})), - bindings=immutabledict({"_in0": a, "_in1": b, "_in2": c}), - shape=(512,), tags=pt.array._get_default_tags(), - axes=pt.array._get_default_axes(1), - dtype=float, - var_to_reduction_descr=immutabledict({"_r0": pt.array.ReductionDescriptor( - frozenset([]) - ), - "_r1": pt.array.ReductionDescriptor( - frozenset([]) - )})) - - w_unified = pt.unify_axes_tags(w) - - assert w_unified.var_to_reduction_descr["_r0"].tags_of_type(TestlibTag) == frozenset([FooTag()]) - assert w_unified.var_to_reduction_descr["_r1"].tags_of_type(TestlibTag) == frozenset([QuuxTag()]) - - assert w_unified.axes[0].tags_of_type(TestlibTag) == frozenset([BazTag()]) - for key in ["_in" + str(i) for i in range(2)]: - assert w_unified.bindings[key].axes[0].tags_of_type(TestlibTag) == frozenset([BazTag()]) - - - def test_ignoring_axes_during_propagation(): from pytools.tag import UniqueTag From 5ccc11fde9885f3be0d0f7a5c8bf56766405437e Mon Sep 17 00:00:00 2001 From: Nick Koskelo Date: Tue, 14 Jan 2025 23:07:56 +0000 Subject: [PATCH 26/28] Remove unused code. --- pytato/transform/metadata.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 3fa275708..f1f5c79ea 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -645,18 +645,6 @@ def rec(self, expr: ArrayOrNames) -> Any: redn_var, self.axis_to_tags.get((expr, redn_var), []) ) - - """ - mymapper: TagReductionAxesMapper = \ - TagReductionAxesMapper( - self.axis_to_tags, - expr_copy, - expr - ) - mymapper(expr_copy.expr) # Tag the axes - expr_copy = mymapper.array_expr # Recover it. - """ - # }}} self._cache[key] = expr_copy From b522193b0753f8d1fe9672b0ac051428c221ab8b Mon Sep 17 00:00:00 2001 From: Nick Date: Fri, 24 Jan 2025 02:22:39 -0600 Subject: [PATCH 27/28] Save off changes moving down to just handling IndexLambdas and those which cannot be transformed into one. 
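
The equation collector now lowers the incoming DAG to IndexLambda form (via
ToIndexLambdaMixin) before recording equations; the remaining mapper methods
raise NotImplementedError for now, and unify_axes_tags hands back the lowered
DAG. Rough usage sketch (illustrative only; the placeholder name and shape are
taken from the tests touched in this series, not part of this patch):

    import pytato as pt

    x = pt.make_placeholder("x", (10, 4))
    y = pt.expand_dims(x, (2, 3)) + x

    # The unified result is now an IndexLambda-based DAG rather than the
    # original expression node.
    y_unified = pt.unify_axes_tags(y)
    assert isinstance(y_unified, pt.IndexLambda)
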
--- pytato/scalar_expr.py | 5 +- pytato/transform/metadata.py | 95 ++++++++++++++++++++++++++++++------ test/test_pytato.py | 54 +++++++++++++++----- 3 files changed, 125 insertions(+), 29 deletions(-) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 7861e3504..efabbdc33 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -163,9 +163,8 @@ def map_reduce(self, expr: Reduce) -> ScalarExpression: for name, bound in expr.bounds.items()})) -IDX_LAMBDA_RE = re.compile(r"^(_r?(0|([1-9][0-9]*)))$") -IDX_LAMBDA_INAME = re.compile(r"^(_(0|([1-9][0-9]*)))$") -IDX_LAMBDA_JUST_REDUCTIONS = re.compile(r"^(_r(0|([1-9][0-9]*)))$") +IDX_LAMBDA_RE = re.compile(r"^(_r?(?P0|[1-9][0-9]*))$") +IDX_LAMBDA_RESERVED_INDEX_PATTERN = re.compile(r"^(_(?P0|[1-9][0-9]*))$") class DependencyMapper(DependencyMapperBase[P]): diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index f1f5c79ea..65c6fd147 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -75,7 +75,7 @@ from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.function import NamedCallResult from pytato.scalar_expr import ( - IDX_LAMBDA_INAME, + IDX_LAMBDA_RESERVED_INDEX_PATTERN, CombineMapper, ) from pytato.transform import ArrayOrNames, CopyMapper, Mapper @@ -116,8 +116,9 @@ def combine(self, set[tuple[prim.Expression, ...]]]]) \ -> dict[BindingName, set[tuple[prim.Expression, ...]]]: out: dict[BindingName, set[tuple[prim.Expression, ...]]] = {} + import operator from functools import reduce - return reduce(lambda x, y: x | y, values, out) + return reduce(operator.or_, values, out) def map_subscript(self, expr: prim.Subscript) -> dict[BindingName, set[tuple[prim.Expression, ...]]]: @@ -127,7 +128,22 @@ def map_subscript(self, expr: prim.Subscript) -> dict[BindingName, """ if isinstance(expr.aggregate, prim.Variable): - return {expr.aggregate.name: {expr.index_tuple}} + name = expr.aggregate.name + base: dict[BindingName, + set[tuple[prim.Expression, ...]]] = {name: {expr.index_tuple}} + + """ + for ind, subexpr in enumerate(expr.index_tuple): + sub = self.rec(subexpr) + if sub: + # we have nested subscripts. + for key, val in sub.items(): + # The new key will be comma separated. + newkey = name + "," + str(ind) + "," + key + base.update({newkey: val}) + """ + return self.combine([base] + [self.rec(subexpr) for _, subexpr in enumerate(expr.index_tuple)]) + #return {expr.aggregate.name: {expr.index_tuple}} return {} def map_algebraic_leaf(self, expr: prim.Expression) -> dict[BindingName, @@ -278,23 +294,36 @@ def map_index_lambda(self, expr: IndexLambda) -> None: index_expr_used = BindingSubscriptsCollector()(expr.expr) - assert len(expr.shape) == expr.ndim for vname, set_of_ind_tuple in index_expr_used.items(): for ind_tuple in set_of_ind_tuple: for axis_ind, var_ind_name in enumerate(ind_tuple): if isinstance(var_ind_name, prim.Variable): lhs: str = self.get_var_for_axis(expr.bindings[vname], - axis_ind) - if IDX_LAMBDA_INAME.fullmatch(var_ind_name.name): - # matched with an iname. - inum = int(var_ind_name.name[1:]) - rhs: str = self.get_var_for_axis(expr, inum) - self.record_equation(lhs, rhs) + axis_ind) + matched_pattern = IDX_LAMBDA_RESERVED_INDEX_PATTERN.fullmatch(var_ind_name.name) + if matched_pattern: + # matched with an axis index. + self.record_equation(lhs, self.get_var_for_axis(expr, + int(matched_pattern.group("index")))) elif var_ind_name.name in expr.var_to_reduction_descr.keys(): - # matched with a reduction iname. 
- rhs = self.get_var_for_axis(expr, var_ind_name.name) - self.record_equation(lhs, rhs) + # matched with a reduction axis. + # We are assuming that this axis is eliminated from the + # axes of the output array. So, the metadata will only be keep + # in the reduction descriptor object which is indexed by the + # var_ind_name.name + self.record_equation(lhs, + self.get_var_for_axis(expr, var_ind_name.name)) + + elif BINDING_NAME_RESERVED_PATTERN.fullmatch(var_ind_name.name): + # This means that we had an index of index. + # So, the metadata propagation with this index is data + # dependent. + pass + else: + pass + #warning("Variable does not match an index pattern. It will + #be ignored for metadata propagation.") return @@ -304,6 +333,7 @@ def map_stack(self, expr: Stack) -> None: and their corresponding axis in *expr*. No equation is added for the newly created axis i.e. :attr:`pytato.array.Stack.axis`. """ + raise NotImplementedError for ary in expr.arrays: self.rec(ary) @@ -331,6 +361,7 @@ def map_concatenate(self, expr: Concatenate) -> None: added for the concatenated axis i.e. :attr:`pytato.array.Concatenate.axis`. """ + raise NotImplementedError for ary in expr.arrays: self.rec(ary) @@ -351,6 +382,7 @@ def map_axis_permutation(self, expr: AxisPermutation its corresponding axis in *expr* as specified by :attr:`pytato.array.AxisPermutation.axis_permutation`. """ + raise NotImplementedError self.rec(expr.array) assert expr.ndim == expr.array.ndim @@ -369,6 +401,7 @@ def map_basic_index(self, expr: BasicIndex) -> None: sliced axis is one which goes along the entire length of the axis with a positive unit stride. """ + raise NotImplementedError self.rec(expr.array) i_out_axis = 0 @@ -401,6 +434,7 @@ def map_contiguous_advanced_index(self, indices adds an equality equation for each non-broadcasted axis of an indexing array to its corresponding axis in *expr*. """ + raise NotImplementedError from pytato.utils import get_shape_after_broadcasting, partition self.rec(expr.array) @@ -488,6 +522,7 @@ def map_reshape(self, expr: Reshape) -> None: output and so no constraints are enforced except when the :class:`pytato.Reshape` has come from a :func:`pytato.expand_dims`. """ + raise NotImplementedError from pytato.tags import ExpandedDimsReshape self.rec(expr.array) @@ -514,6 +549,7 @@ def map_einsum(self, expr: Einsum) -> None: :func:`pytato.einsum` thereby having the same the :class:`~pytato.array.EinsumAxisDescriptor`. """ + raise NotImplementedError from pytato.array import EinsumAxisDescriptor, EinsumElementwiseAxis for arg in expr.args: @@ -535,6 +571,7 @@ def map_einsum(self, expr: Einsum) -> None: descr_to_var[descr] = in_tag_var def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: + raise NotImplementedError for _, subexpr in sorted(expr._data.items()): self.rec(subexpr) @@ -542,6 +579,7 @@ def map_loopy_call(self, expr: LoopyCall) -> None: """ Does not add any equations. """ + raise NotImplementedError for _, subexpr in sorted(expr.bindings.items()): if isinstance(subexpr, Array): self.rec(subexpr) @@ -551,6 +589,7 @@ def map_loopy_call(self, expr: LoopyCall) -> None: # high level ops, but that's quite involved and probably not worth it. def map_named_array(self, expr: NamedArray) -> None: + raise NotImplementedError self.rec(expr._container) def map_distributed_send_ref_holder(self, @@ -563,6 +602,7 @@ def map_distributed_send_ref_holder(self, equations are added between each axis of *expr* and its corresponding axis in the pass-through data. 
""" + raise NotImplementedError self.rec(expr.passthrough_data) self.rec(expr.send.data) for idim in range(expr.ndim): @@ -577,6 +617,7 @@ def map_distributed_recv(self, :class:`pytato.DistributedRecv` does not have any operands and so no more equations are deduced. """ + raise NotImplementedError def map_named_call_result(self, expr: NamedCallResult) -> Array: raise NotImplementedError( @@ -696,7 +737,26 @@ def unify_axes_tags( """ equations_collector = equations_collector_t(tag_t) - equations_collector(expr) + # First we will convert the expression to a series of IndexLambda operations. + + from pytato.transform.lower_to_index_lambda import ToIndexLambdaMixin, to_index_lambda + from pytato.transform import TransformMapper + from pytato.diagnostic import CannotBeLoweredToIndexLambda + mapped_expr = to_index_lambda(expr) + + class MyIndexMapper(TransformMapper, ToIndexLambdaMixin): + def handle_unsupported_array(self, expr: Any) -> Any: + raise CannotBeLoweredToIndexLambda(type(expr)) + + def map_placeholder(self, expr: Placeholder) -> Placeholder: + return expr + + + mymapper = MyIndexMapper() + mapped_expr = mymapper(expr) + + breakpoint() + equations_collector(mapped_expr) # start BFS traversal with the known tags as the sources. # From the equations build a Propagation Graph @@ -709,6 +769,9 @@ def unify_axes_tags( ) known_tag_vars = frozenset(equations_collector.known_tag_to_var.values()) + + # Reduction axes are specified by a str but all other axes are specified + # by an integer. Note that the axes are still uniquely identified. axis_to_solved_tags: dict[tuple[Array, int | str], set[Tag]] = {} propagation_graph = undirected_graph_from_edges( @@ -721,6 +784,8 @@ def unify_axes_tags( }) for (ary, ax), ax_var in equations_collector.axis_to_var.items(): + # Reduction axes do not follow AxisIgnoredForPropagation. + # They cannot propagate the information to descendant of the array anyway. 
if isinstance(ax, int): if ary.axes[ax].tags_of_type(AxisIgnoredForPropagationTag): ignored_vars.update({ax_var}) @@ -736,6 +801,6 @@ def unify_axes_tags( return AxisTagAttacher(axis_to_solved_tags, tag_corresponding_redn_descr=unify_redn_descrs, - )(expr) + )(mapped_expr) # vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index 498a4cd17..ec553c342 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1229,30 +1229,29 @@ def test_reserved_binding_name_patterns(): def test_reserved_scalar_iname_patterns(): from pytato.scalar_expr import ( - IDX_LAMBDA_INAME, - IDX_LAMBDA_JUST_REDUCTIONS, + IDX_LAMBDA_RESERVED_INDEX_PATTERN, IDX_LAMBDA_RE, ) - test_strings = ["_r0", "_r000", "_r01", "_00", "_r101", "_1", "_0", "_101"] + test_strings = ["_r0", "_r000", "_r01", "_00", "_r101", "_1", "_0", "_101", "_r"] assert IDX_LAMBDA_RE.fullmatch(test_strings[0]) - assert not IDX_LAMBDA_INAME.fullmatch(test_strings[0]) - assert IDX_LAMBDA_JUST_REDUCTIONS.fullmatch(test_strings[0]) + assert not IDX_LAMBDA_RESERVED_INDEX_PATTERN.fullmatch(test_strings[0]) - for pat in [IDX_LAMBDA_INAME, IDX_LAMBDA_RE, IDX_LAMBDA_JUST_REDUCTIONS]: + for pat in [IDX_LAMBDA_RESERVED_INDEX_PATTERN, IDX_LAMBDA_RE]: assert not pat.fullmatch(test_strings[1]) assert not pat.fullmatch(test_strings[2]) assert not pat.fullmatch(test_strings[3]) assert IDX_LAMBDA_RE.fullmatch(test_strings[4]) - assert not IDX_LAMBDA_INAME.fullmatch(test_strings[4]) - assert IDX_LAMBDA_JUST_REDUCTIONS.fullmatch(test_strings[4]) + assert not IDX_LAMBDA_RESERVED_INDEX_PATTERN.fullmatch(test_strings[4]) - for i in range(5, len(test_strings)): + for i in range(5, len(test_strings)-1): assert IDX_LAMBDA_RE.fullmatch(test_strings[i]) - assert IDX_LAMBDA_INAME.fullmatch(test_strings[i]) - assert not IDX_LAMBDA_JUST_REDUCTIONS.fullmatch(test_strings[i]) + assert IDX_LAMBDA_RESERVED_INDEX_PATTERN.fullmatch(test_strings[i]) + + assert not IDX_LAMBDA_RE.fullmatch(test_strings[-1]) + assert not IDX_LAMBDA_RESERVED_INDEX_PATTERN.fullmatch(test_strings[-1]) def test_cached_walk_mapper_with_extra_args(): @@ -1285,6 +1284,39 @@ def post_visit(self, expr, passed_number): # passing incorrect argument should raise TypeError while calling post_visit my_walk_mapper(dag, bad_arg_name=7) +def test_unify_axes_tags_indexlambda(): + from testlib import BarTag, BazTag, FooTag, QuuxTag, TestlibTag + + from pytato.array import EinsumReductionAxis + from pymbolic import primitives as prim + from immutabledict import immutabledict + + x = pt.make_placeholder("x", (10, 4)) + x = x.with_tagged_axis(0, FooTag()) + + y = pt.make_placeholder("y", (4, 10)) + y = y.with_tagged_axis(0, BarTag()) + + z = pt.IndexLambda(expr=prim.Subscript(prim.Variable("_in0"), + (prim.Variable("_0"), + prim.Subscript(prim.Variable("_in1"), + (prim.Variable("_1"), 0))) + ), + bindings=immutabledict({"_in0": x, "_in1": y}), + dtype=float, axes=pt.array._get_default_axes(2), + tags=pt.array._get_default_tags(), + shape=(10,4), + var_to_reduction_descr=immutabledict({})) + + z_unified = pt.unify_axes_tags(z) + + assert z_unified.axes[0].tags_of_type(TestlibTag) == frozenset([FooTag()]) + assert z_unified.axes[1].tags_of_type(TestlibTag) == frozenset([BarTag()]) + + assert z_unified.bindings["_in1"].axes[0].tags_of_type(TestlibTag) == frozenset([BarTag()]) + + assert z_unified.bindings["_in0"].axes[0].tags_of_type(TestlibTag) == frozenset([FooTag()]) + assert z_unified.bindings["_in1"].axes[1].tags_of_type(TestlibTag) == frozenset([]) def test_unify_axes_tags(): from testlib import 
BarTag, BazTag, FooTag, QuuxTag, TestlibTag From d967d5d7c77ed46d0d61af471ae9ff4a9bc7e668 Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 27 Jan 2025 03:16:34 -0600 Subject: [PATCH 28/28] Preconvert to using IndexLambdas, so all of the cases need to be handled as IndexLambdas. This propagates metadata through partially indexed axes of arrays. --- pytato/transform/metadata.py | 27 +++++++++++++++++++++++++++ test/test_pytato.py | 28 ++++++++++++++++++---------- 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 65c6fd147..6082317c9 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -77,6 +77,7 @@ from pytato.scalar_expr import ( IDX_LAMBDA_RESERVED_INDEX_PATTERN, CombineMapper, + get_dependencies as get_dependencies_scalar ) from pytato.transform import ArrayOrNames, CopyMapper, Mapper from pytato.utils import are_shape_components_equal, are_shapes_equal @@ -290,14 +291,19 @@ def map_index_lambda(self, expr: IndexLambda) -> None: iname format, "_[0-9]+", and the axis of the output specified by the iname. """ for bnd in expr.bindings.values(): + breakpoint() self.rec(bnd) index_expr_used = BindingSubscriptsCollector()(expr.expr) + + breakpoint() for vname, set_of_ind_tuple in index_expr_used.items(): for ind_tuple in set_of_ind_tuple: for axis_ind, var_ind_name in enumerate(ind_tuple): + + variables_used = get_dependencies_scalar(var_ind_name) if isinstance(var_ind_name, prim.Variable): lhs: str = self.get_var_for_axis(expr.bindings[vname], axis_ind) matched_pattern = IDX_LAMBDA_RESERVED_INDEX_PATTERN.fullmatch(var_ind_name.name) if matched_pattern: # matched with an axis index. self.record_equation(lhs, self.get_var_for_axis(expr, int(matched_pattern.group("index")))) elif var_ind_name.name in expr.var_to_reduction_descr.keys(): # matched with a reduction axis. # We are assuming that this axis is eliminated from the # axes of the output array.
So, the metadata will only be keep + # in the reduction descriptor object which is indexed by the + # var_ind_name.name + self.record_equation(lhs, + self.get_var_for_axis(expr, ind_name)) + + return def map_stack(self, expr: Stack) -> None: diff --git a/test/test_pytato.py b/test/test_pytato.py index ec553c342..2b4a2f64e 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1284,6 +1284,7 @@ def post_visit(self, expr, passed_number): # passing incorrect argument should raise TypeError while calling post_visit my_walk_mapper(dag, bad_arg_name=7) +""" def test_unify_axes_tags_indexlambda(): from testlib import BarTag, BazTag, FooTag, QuuxTag, TestlibTag @@ -1317,6 +1318,7 @@ def test_unify_axes_tags_indexlambda(): assert z_unified.bindings["_in0"].axes[0].tags_of_type(TestlibTag) == frozenset([FooTag()]) assert z_unified.bindings["_in1"].axes[1].tags_of_type(TestlibTag) == frozenset([]) +""" def test_unify_axes_tags(): from testlib import BarTag, BazTag, FooTag, QuuxTag, TestlibTag @@ -1332,6 +1334,8 @@ def test_unify_axes_tags(): y = pt.expand_dims(x, (2, 3)) + x y_unified = pt.unify_axes_tags(y) + + assert isinstance(y_unified, pt.IndexLambda) assert (y_unified.axes[0].tags_of_type(TestlibTag) == frozenset([FooTag()])) assert (y_unified.axes[2].tags_of_type(TestlibTag) @@ -1354,20 +1358,23 @@ def test_unify_axes_tags(): z = pt.einsum("ij, ij -> i", x, y) z_unified = pt.unify_axes_tags(z) + assert isinstance(z_unified, pt.IndexLambda) assert (z_unified.axes[0].tags_of_type(TestlibTag) == frozenset([FooTag()])) - assert (z_unified.args[0].axes[1].tags_of_type(TestlibTag) + assert (z_unified.bindings["_in0"].axes[1].tags_of_type(TestlibTag) == frozenset([BarTag()])) - assert (z_unified.args[1].axes[0].tags_of_type(TestlibTag) + assert (z_unified.bindings["_in1"].axes[0].tags_of_type(TestlibTag) == frozenset([FooTag()])) - assert (z_unified.redn_axis_to_redn_descr[EinsumReductionAxis(0)] + + keys = list(z_unified.var_to_reduction_descr.keys()) + assert len(keys) == 1 + assert (z_unified.var_to_reduction_descr[keys[0]] .tags_of_type(TestlibTag) == frozenset([BarTag()])) # }}} # {{{ 3. advanced indexing - idx1 = pt.make_placeholder("idx1", (42, 1), "int32") idx1 = idx1.with_tagged_axis(0, FooTag()) @@ -1387,14 +1394,15 @@ def test_unify_axes_tags(): assert (y_unified.axes[0].tags_of_type(TestlibTag) == frozenset([BazTag()])) assert (y_unified.axes[1].tags_of_type(TestlibTag) - == frozenset()) + == frozenset([QuuxTag()])) + # A portion of an axis still has the same units as the whole axis. assert (y_unified.axes[2].tags_of_type(TestlibTag) - == frozenset([FooTag()])) + == frozenset([FooTag(), QuuxTag()])) assert (y_unified.axes[3].tags_of_type(TestlibTag) - == frozenset([BarTag()])) + == frozenset([BarTag(), QuuxTag()])) assert (y_unified.axes[4].tags_of_type(TestlibTag) == frozenset([QuuxTag()])) - + # }}} # {{ Reduction Operations with IndexLambda @@ -1499,7 +1507,7 @@ def get_def_reduction_descrs(): # }}} # }} - +""" def test_unify_axes_tags_with_unbroadcastable_expressions(): a = pt.make_placeholder("a", (512, 10, 8)) @@ -1535,7 +1543,7 @@ def test_unify_axes_tags_with_unbroadcastable_expressions(): term = z_unified.bindings[key] assert (term.axes[0].tags_of_type(TestlibTag) == frozenset([BazTag()])) assert (term.axes[1].tags_of_type(TestlibTag) == frozenset([QuuxTag()])) - +""" def test_ignoring_axes_during_propagation(): from pytools.tag import UniqueTag