diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index e74007c0f..efabbdc33 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -163,7 +163,8 @@ def map_reduce(self, expr: Reduce) -> ScalarExpression: for name, bound in expr.bounds.items()})) -IDX_LAMBDA_RE = re.compile(r"_r?(0|([1-9][0-9]*))") +IDX_LAMBDA_RE = re.compile(r"^(_r?(?P0|[1-9][0-9]*))$") +IDX_LAMBDA_RESERVED_INDEX_PATTERN = re.compile(r"^(_(?P0|[1-9][0-9]*))$") class DependencyMapper(DependencyMapperBase[P]): diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 8cd520d94..6082317c9 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -38,17 +38,20 @@ THE SOFTWARE. """ - import logging +import re from typing import ( TYPE_CHECKING, Any, + ParamSpec, + TypeAlias, TypeVar, cast, ) from bidict import bidict +import pymbolic.primitives as prim from pytools import UniqueNameGenerator from pytools.tag import Tag @@ -56,7 +59,6 @@ AbstractResultWithNamedArrays, AdvancedIndexInContiguousAxes, Array, - ArrayOrScalar, AxisPermutation, BasicIndex, Concatenate, @@ -70,18 +72,13 @@ Reshape, Stack, ) -from pytato.diagnostic import UnknownIndexLambdaExpr from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder -from pytato.raising import ( - BinaryOp, - BroadcastOp, - C99CallOp, - FullOp, - ReduceOp, - WhereOp, - index_lambda_to_high_level_op, +from pytato.function import NamedCallResult +from pytato.scalar_expr import ( + IDX_LAMBDA_RESERVED_INDEX_PATTERN, + CombineMapper, + get_dependencies as get_dependencies_scalar ) -from pytato.scalar_expr import SCALAR_CLASSES from pytato.transform import ArrayOrNames, CopyMapper, Mapper from pytato.utils import are_shape_components_equal, are_shapes_equal @@ -90,7 +87,7 @@ if TYPE_CHECKING: - from collections.abc import Collection, Mapping + from collections.abc import Collection, Iterable, Mapping from pytato.function import NamedCallResult from pytato.loopy import LoopyCall @@ -98,9 +95,71 @@ GraphNodeT = TypeVar("GraphNodeT") +BindingName: TypeAlias = str +P = ParamSpec("P") + +BINDING_NAME_RESERVED_PATTERN = re.compile(r"^(_in?(0|([1-9][0-9]*)))$") + + +# {{{ BindingSubscriptsCollector + + +class BindingSubscriptsCollector(CombineMapper[dict[BindingName, + set[tuple[prim.Expression, ...]]], + []]): + """ + Return all the subscript expressions used by a variable specified by BindingName. + Ex: + _in1[_0,_1] would result in an dictionary entry {"_in1": ("_0", "_1")}. + """ + def combine(self, + values: Iterable[dict[BindingName, + set[tuple[prim.Expression, ...]]]]) \ + -> dict[BindingName, set[tuple[prim.Expression, ...]]]: + out: dict[BindingName, set[tuple[prim.Expression, ...]]] = {} + import operator + from functools import reduce + return reduce(operator.or_, values, out) + + def map_subscript(self, expr: prim.Subscript) -> dict[BindingName, + set[tuple[prim.Expression, ...]]]: + """ + Record the indexing expression if the Subscript expression has a prim.Variable + as its aggregate. + """ + + if isinstance(expr.aggregate, prim.Variable): + name = expr.aggregate.name + base: dict[BindingName, + set[tuple[prim.Expression, ...]]] = {name: {expr.index_tuple}} + + """ + for ind, subexpr in enumerate(expr.index_tuple): + sub = self.rec(subexpr) + if sub: + # we have nested subscripts. + for key, val in sub.items(): + # The new key will be comma separated. + newkey = name + "," + str(ind) + "," + key + base.update({newkey: val}) + """ + return self.combine([base] + [self.rec(subexpr) for _, subexpr in enumerate(expr.index_tuple)]) + #return {expr.aggregate.name: {expr.index_tuple}} + return {} + + def map_algebraic_leaf(self, expr: prim.Expression) -> dict[BindingName, + set[tuple[prim.Expression, ...]]]: + + return {} + + def map_constant(self, expr: object) -> dict[BindingName, + set[tuple[prim.Expression, ...]]]: + return {} +# }}} # {{{ AxesTagsEquationCollector + class AxesTagsEquationCollector(Mapper[None, []]): r""" Records equations arising from operand/output axes equivalence for an array @@ -156,7 +215,7 @@ def __init__(self, tag_t: type[Tag]) -> None: # axis_to_var: mapping from (array, iaxis) to the variable to be # used for unification. - self.axis_to_var: bidict[tuple[Array, int], str] = bidict() + self.axis_to_var: bidict[tuple[Array, int | str], str] = bidict() self.known_tag_to_var: dict[Tag, str] = {} self.equations: list[tuple[str, str]] = [] @@ -165,7 +224,7 @@ def __init__(self, tag_t: type[Tag]) -> None: # {{{ unification helpers - def get_var_for_axis(self, ary: Array, iaxis: int) -> str: + def get_var_for_axis(self, ary: Array, iaxis: int | str) -> str: key = (ary, iaxis) try: @@ -227,75 +286,73 @@ def _map_input_base(self, expr: InputArgumentBase) -> None: def map_index_lambda(self, expr: IndexLambda) -> None: """ - The propagation semantics for a :class:`~pytato.IndexLambda` are - implemented only for operations that can be raised to a - :class:`~pytato.raising.HighLevelOp`. In such cases, an equality - equation is recorded for every non-broadcasted axis of an operand and - its corresponding axis of *expr*. + Equality conditions are added between an axis of the operands which is indexed + by a :class:`~pymbolic.Variable` which has a name that follows the reserved + iname format, "_[0-9]+", and the axis of the output specified by the iname. """ for bnd in expr.bindings.values(): + breakpoint() self.rec(bnd) - try: - hlo = index_lambda_to_high_level_op(expr) - except UnknownIndexLambdaExpr: - from warnings import warn - warn(f"'{expr}' is an unknown index lambda type" - " no tags were propagated across it.", stacklevel=1) - # no propagation semantics implemented for such cases - return - - if isinstance(hlo, BinaryOp): - subexprs: tuple[ArrayOrScalar, ...] = (hlo.x1, hlo.x2) - elif isinstance(hlo, WhereOp): - subexprs = (hlo.condition, hlo.then, hlo.else_) - elif isinstance(hlo, FullOp): - # A full-op does not impose any equations - subexprs = () - elif isinstance(hlo, BroadcastOp): - subexprs = (hlo.x,) - elif isinstance(hlo, C99CallOp): - subexprs = hlo.args - elif isinstance(hlo, ReduceOp): - - # {{{ ReduceOp doesn't quite involve broadcasting - - i_out_axis = 0 - for i_in_axis in range(hlo.x.ndim): - if i_in_axis not in hlo.axes: - self.record_equation( - self.get_var_for_axis(hlo.x, i_in_axis), - self.get_var_for_axis(expr, i_out_axis) - ) - i_out_axis += 1 - - assert i_out_axis == expr.ndim - - # }}} - - return - - else: - raise NotImplementedError(type(hlo)) - - for subexpr in subexprs: - if isinstance(subexpr, Array): - for i_in_axis, i_out_axis in zip( - range(subexpr.ndim), - range(expr.ndim-subexpr.ndim, expr.ndim), - strict=True): - in_dim = subexpr.shape[i_in_axis] - out_dim = expr.shape[i_out_axis] - if are_shape_components_equal(in_dim, out_dim): - self.record_equation( - self.get_var_for_axis(subexpr, i_in_axis), - self.get_var_for_axis(expr, i_out_axis) - ) - else: - # i_in_axis is broadcasted => do not propagate - assert are_shape_components_equal(in_dim, 1) - else: - assert isinstance(subexpr, SCALAR_CLASSES) + index_expr_used = BindingSubscriptsCollector()(expr.expr) + + + + breakpoint() + for vname, set_of_ind_tuple in index_expr_used.items(): + for ind_tuple in set_of_ind_tuple: + for axis_ind, var_ind_name in enumerate(ind_tuple): + + variables_used = get_dependencies_scalar(var_ind_name) + if isinstance(var_ind_name, prim.Variable): + lhs: str = self.get_var_for_axis(expr.bindings[vname], + axis_ind) + matched_pattern = IDX_LAMBDA_RESERVED_INDEX_PATTERN.fullmatch(var_ind_name.name) + if matched_pattern: + # matched with an axis index. + self.record_equation(lhs, self.get_var_for_axis(expr, + int(matched_pattern.group("index")))) + elif var_ind_name.name in expr.var_to_reduction_descr.keys(): + # matched with a reduction axis. + # We are assuming that this axis is eliminated from the + # axes of the output array. So, the metadata will only be keep + # in the reduction descriptor object which is indexed by the + # var_ind_name.name + self.record_equation(lhs, + self.get_var_for_axis(expr, var_ind_name.name)) + + elif BINDING_NAME_RESERVED_PATTERN.fullmatch(var_ind_name.name): + # This means that we had an index of index. + # So, the metadata propagation with this index is data + # dependent. + pass + else: + pass + #warning("Variable does not match an index pattern. It will + #be ignored for metadata propagation.") + + # We need to add an equation if the index name is the only variable + # for that axis. This includes if there is scaled indexing. + for ind_name in variables_used: + breakpoint() + lhs: str = self.get_var_for_axis(expr.bindings[vname], + axis_ind) + matched_pattern = IDX_LAMBDA_RESERVED_INDEX_PATTERN.fullmatch(ind_name) + if matched_pattern: + # matched with an axis index of the output. + self.record_equation(lhs, self.get_var_for_axis(expr, + int(matched_pattern.group("index")))) + elif ind_name in expr.var_to_reduction_descr.keys(): + # matched with a reduction axis. + # We are assuming that this axis is eliminated from the + # axes of the output array. So, the metadata will only be keep + # in the reduction descriptor object which is indexed by the + # var_ind_name.name + self.record_equation(lhs, + self.get_var_for_axis(expr, ind_name)) + + + return def map_stack(self, expr: Stack) -> None: """ @@ -303,6 +360,7 @@ def map_stack(self, expr: Stack) -> None: and their corresponding axis in *expr*. No equation is added for the newly created axis i.e. :attr:`pytato.array.Stack.axis`. """ + raise NotImplementedError for ary in expr.arrays: self.rec(ary) @@ -330,6 +388,7 @@ def map_concatenate(self, expr: Concatenate) -> None: added for the concatenated axis i.e. :attr:`pytato.array.Concatenate.axis`. """ + raise NotImplementedError for ary in expr.arrays: self.rec(ary) @@ -350,6 +409,7 @@ def map_axis_permutation(self, expr: AxisPermutation its corresponding axis in *expr* as specified by :attr:`pytato.array.AxisPermutation.axis_permutation`. """ + raise NotImplementedError self.rec(expr.array) assert expr.ndim == expr.array.ndim @@ -368,6 +428,7 @@ def map_basic_index(self, expr: BasicIndex) -> None: sliced axis is one which goes along the entire length of the axis with a positive unit stride. """ + raise NotImplementedError self.rec(expr.array) i_out_axis = 0 @@ -400,6 +461,7 @@ def map_contiguous_advanced_index(self, indices adds an equality equation for each non-broadcasted axis of an indexing array to its corresponding axis in *expr*. """ + raise NotImplementedError from pytato.utils import get_shape_after_broadcasting, partition self.rec(expr.array) @@ -487,6 +549,7 @@ def map_reshape(self, expr: Reshape) -> None: output and so no constraints are enforced except when the :class:`pytato.Reshape` has come from a :func:`pytato.expand_dims`. """ + raise NotImplementedError from pytato.tags import ExpandedDimsReshape self.rec(expr.array) @@ -513,6 +576,7 @@ def map_einsum(self, expr: Einsum) -> None: :func:`pytato.einsum` thereby having the same the :class:`~pytato.array.EinsumAxisDescriptor`. """ + raise NotImplementedError from pytato.array import EinsumAxisDescriptor, EinsumElementwiseAxis for arg in expr.args: @@ -534,6 +598,7 @@ def map_einsum(self, expr: Einsum) -> None: descr_to_var[descr] = in_tag_var def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: + raise NotImplementedError for _, subexpr in sorted(expr._data.items()): self.rec(subexpr) @@ -541,6 +606,7 @@ def map_loopy_call(self, expr: LoopyCall) -> None: """ Does not add any equations. """ + raise NotImplementedError for _, subexpr in sorted(expr.bindings.items()): if isinstance(subexpr, Array): self.rec(subexpr) @@ -550,6 +616,7 @@ def map_loopy_call(self, expr: LoopyCall) -> None: # high level ops, but that's quite involved and probably not worth it. def map_named_array(self, expr: NamedArray) -> None: + raise NotImplementedError self.rec(expr._container) def map_distributed_send_ref_holder(self, @@ -562,6 +629,7 @@ def map_distributed_send_ref_holder(self, equations are added between each axis of *expr* and its corresponding axis in the pass-through data. """ + raise NotImplementedError self.rec(expr.passthrough_data) self.rec(expr.send.data) for idim in range(expr.ndim): @@ -576,6 +644,7 @@ def map_distributed_recv(self, :class:`pytato.DistributedRecv` does not have any operands and so no more equations are deduced. """ + raise NotImplementedError def map_named_call_result(self, expr: NamedCallResult) -> Array: raise NotImplementedError( @@ -594,10 +663,11 @@ class AxisTagAttacher(CopyMapper): A mapper that tags the axes in a DAG as prescribed by *axis_to_tags*. """ def __init__(self, - axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]], + axis_to_tags: Mapping[tuple[Array, int | str], Collection[Tag]], tag_corresponding_redn_descr: bool): super().__init__() - self.axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] = axis_to_tags + self.axis_to_tags: Mapping[tuple[Array, int | str], + Collection[Tag]] = axis_to_tags self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr def rec(self, expr: ArrayOrNames) -> Any: @@ -634,18 +704,15 @@ def rec(self, expr: ArrayOrNames) -> Any: if isinstance(expr, IndexLambda): assert isinstance(expr_copy, IndexLambda) - try: - hlo = index_lambda_to_high_level_op(expr) - except UnknownIndexLambdaExpr: - pass - else: - if isinstance(hlo, ReduceOp): - for iaxis, redn_var in hlo.axes.items(): - expr_copy = expr_copy.with_tagged_reduction( + if expr_copy.var_to_reduction_descr: + # This is a reduction operation. + # We need to find the axes that are reduced over + # and update the tag/tag them appropriately. + for redn_var in expr.var_to_reduction_descr.keys(): + expr_copy = expr_copy.with_tagged_reduction( redn_var, - self.axis_to_tags.get((hlo.x, iaxis), []) + self.axis_to_tags.get((expr, redn_var), []) ) - # }}} self._cache[key] = expr_copy @@ -697,7 +764,26 @@ def unify_axes_tags( """ equations_collector = equations_collector_t(tag_t) - equations_collector(expr) + # First we will convert the expression to a series of IndexLambda operations. + + from pytato.transform.lower_to_index_lambda import ToIndexLambdaMixin, to_index_lambda + from pytato.transform import TransformMapper + from pytato.diagnostic import CannotBeLoweredToIndexLambda + mapped_expr = to_index_lambda(expr) + + class MyIndexMapper(TransformMapper, ToIndexLambdaMixin): + def handle_unsupported_array(self, expr: Any) -> Any: + raise CannotBeLoweredToIndexLambda(type(expr)) + + def map_placeholder(self, expr: Placeholder) -> Placeholder: + return expr + + + mymapper = MyIndexMapper() + mapped_expr = mymapper(expr) + + breakpoint() + equations_collector(mapped_expr) # start BFS traversal with the known tags as the sources. # From the equations build a Propagation Graph @@ -710,7 +796,10 @@ def unify_axes_tags( ) known_tag_vars = frozenset(equations_collector.known_tag_to_var.values()) - axis_to_solved_tags: dict[tuple[Array, int], set[Tag]] = {} + + # Reduction axes are specified by a str but all other axes are specified + # by an integer. Note that the axes are still uniquely identified. + axis_to_solved_tags: dict[tuple[Array, int | str], set[Tag]] = {} propagation_graph = undirected_graph_from_edges( equations_collector.equations @@ -721,10 +810,12 @@ def unify_axes_tags( if isinstance(tag, AxisIgnoredForPropagationTag) }) - ignored_vars.update({ - ax_var for (ary, ax), ax_var in equations_collector.axis_to_var.items() - if ary.axes[ax].tags_of_type(AxisIgnoredForPropagationTag) - }) + for (ary, ax), ax_var in equations_collector.axis_to_var.items(): + # Reduction axes do not follow AxisIgnoredForPropagation. + # They cannot propagate the information to descendant of the array anyway. + if isinstance(ax, int): + if ary.axes[ax].tags_of_type(AxisIgnoredForPropagationTag): + ignored_vars.update({ax_var}) for tag, var in equations_collector.known_tag_to_var.items(): reachable_nodes = get_reachable_nodes(propagation_graph, var, @@ -737,6 +828,6 @@ def unify_axes_tags( return AxisTagAttacher(axis_to_solved_tags, tag_corresponding_redn_descr=unify_redn_descrs, - )(expr) + )(mapped_expr) # vim: fdm=marker diff --git a/test/test_codegen.py b/test/test_codegen.py index 5b62ce468..6784be948 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -916,11 +916,11 @@ def test_einsum_with_parameterized_shapes(ctx_factory): m_in = np.random.randint(2, 20) n_in = np.random.randint(2, 20) - def _get_a_shape(_m, _n): - return (2*_m+1, 3*_n+7) + def _get_a_shape(m_, n_): + return (2*m_+1, 3*n_+7) - def _get_x_shape(_m, _n): - return (3*_n+7, ) + def _get_x_shape(_m, n_): + return (3*n_+7, ) A_in = np.random.rand(*_get_a_shape(m_in, n_in)) # noqa: N806 x_in = np.random.rand(*_get_x_shape(m_in, n_in)) @@ -1570,8 +1570,8 @@ def test_regression_reduction_in_conditional(ctx_factory): ctx = ctx_factory() cq = cl.CommandQueue(ctx) - def kernel(usr_np, _pt_data_9): - pt_tmp_53 = _pt_data_9 @ _pt_data_9 + def kernel(usr_np, pt_data_9): + pt_tmp_53 = pt_data_9 @ pt_data_9 pt_tmp_42 = usr_np.maximum(pt_tmp_53, pt_tmp_53) pt_tmp_27 = usr_np.sum(pt_tmp_42) pt_tmp_0 = usr_np.maximum(pt_tmp_27, pt_tmp_53) @@ -1579,7 +1579,7 @@ def kernel(usr_np, _pt_data_9): def get_np_input_args(): return { - "_pt_data_9": np.ones((2, 2)), + "pt_data_9": np.ones((2, 2)), } np_inputs = get_np_input_args() diff --git a/test/test_pytato.py b/test/test_pytato.py index 45d333c32..2b4a2f64e 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1213,6 +1213,47 @@ def test_lower_to_index_lambda(): assert isinstance(binding, Reshape) +def test_reserved_binding_name_patterns(): + from pytato.transform.metadata import BINDING_NAME_RESERVED_PATTERN + + fail_strings = ["_r0", "_r000", "_r01", "_00", "_r101", "_1", "_0", "_101", + "_in", "_in00", "1_in", "_in01"] + pass_strings = ["_in0", "_in1", "_in554", "_in10"] + + for test_str in fail_strings: + assert not BINDING_NAME_RESERVED_PATTERN.fullmatch(test_str) + + for test_str in pass_strings: + assert BINDING_NAME_RESERVED_PATTERN.fullmatch(test_str) + + +def test_reserved_scalar_iname_patterns(): + from pytato.scalar_expr import ( + IDX_LAMBDA_RESERVED_INDEX_PATTERN, + IDX_LAMBDA_RE, + ) + + test_strings = ["_r0", "_r000", "_r01", "_00", "_r101", "_1", "_0", "_101", "_r"] + + assert IDX_LAMBDA_RE.fullmatch(test_strings[0]) + assert not IDX_LAMBDA_RESERVED_INDEX_PATTERN.fullmatch(test_strings[0]) + + for pat in [IDX_LAMBDA_RESERVED_INDEX_PATTERN, IDX_LAMBDA_RE]: + assert not pat.fullmatch(test_strings[1]) + assert not pat.fullmatch(test_strings[2]) + assert not pat.fullmatch(test_strings[3]) + + assert IDX_LAMBDA_RE.fullmatch(test_strings[4]) + assert not IDX_LAMBDA_RESERVED_INDEX_PATTERN.fullmatch(test_strings[4]) + + for i in range(5, len(test_strings)-1): + assert IDX_LAMBDA_RE.fullmatch(test_strings[i]) + assert IDX_LAMBDA_RESERVED_INDEX_PATTERN.fullmatch(test_strings[i]) + + assert not IDX_LAMBDA_RE.fullmatch(test_strings[-1]) + assert not IDX_LAMBDA_RESERVED_INDEX_PATTERN.fullmatch(test_strings[-1]) + + def test_cached_walk_mapper_with_extra_args(): from testlib import RandomDAGContext, make_random_dag @@ -1243,6 +1284,41 @@ def post_visit(self, expr, passed_number): # passing incorrect argument should raise TypeError while calling post_visit my_walk_mapper(dag, bad_arg_name=7) +""" +def test_unify_axes_tags_indexlambda(): + from testlib import BarTag, BazTag, FooTag, QuuxTag, TestlibTag + + from pytato.array import EinsumReductionAxis + from pymbolic import primitives as prim + from immutabledict import immutabledict + + x = pt.make_placeholder("x", (10, 4)) + x = x.with_tagged_axis(0, FooTag()) + + y = pt.make_placeholder("y", (4, 10)) + y = y.with_tagged_axis(0, BarTag()) + + z = pt.IndexLambda(expr=prim.Subscript(prim.Variable("_in0"), + (prim.Variable("_0"), + prim.Subscript(prim.Variable("_in1"), + (prim.Variable("_1"), 0))) + ), + bindings=immutabledict({"_in0": x, "_in1": y}), + dtype=float, axes=pt.array._get_default_axes(2), + tags=pt.array._get_default_tags(), + shape=(10,4), + var_to_reduction_descr=immutabledict({})) + + z_unified = pt.unify_axes_tags(z) + + assert z_unified.axes[0].tags_of_type(TestlibTag) == frozenset([FooTag()]) + assert z_unified.axes[1].tags_of_type(TestlibTag) == frozenset([BarTag()]) + + assert z_unified.bindings["_in1"].axes[0].tags_of_type(TestlibTag) == frozenset([BarTag()]) + + assert z_unified.bindings["_in0"].axes[0].tags_of_type(TestlibTag) == frozenset([FooTag()]) + assert z_unified.bindings["_in1"].axes[1].tags_of_type(TestlibTag) == frozenset([]) +""" def test_unify_axes_tags(): from testlib import BarTag, BazTag, FooTag, QuuxTag, TestlibTag @@ -1258,6 +1334,8 @@ def test_unify_axes_tags(): y = pt.expand_dims(x, (2, 3)) + x y_unified = pt.unify_axes_tags(y) + + assert isinstance(y_unified, pt.IndexLambda) assert (y_unified.axes[0].tags_of_type(TestlibTag) == frozenset([FooTag()])) assert (y_unified.axes[2].tags_of_type(TestlibTag) @@ -1280,20 +1358,23 @@ def test_unify_axes_tags(): z = pt.einsum("ij, ij -> i", x, y) z_unified = pt.unify_axes_tags(z) + assert isinstance(z_unified, pt.IndexLambda) assert (z_unified.axes[0].tags_of_type(TestlibTag) == frozenset([FooTag()])) - assert (z_unified.args[0].axes[1].tags_of_type(TestlibTag) + assert (z_unified.bindings["_in0"].axes[1].tags_of_type(TestlibTag) == frozenset([BarTag()])) - assert (z_unified.args[1].axes[0].tags_of_type(TestlibTag) + assert (z_unified.bindings["_in1"].axes[0].tags_of_type(TestlibTag) == frozenset([FooTag()])) - assert (z_unified.redn_axis_to_redn_descr[EinsumReductionAxis(0)] + + keys = list(z_unified.var_to_reduction_descr.keys()) + assert len(keys) == 1 + assert (z_unified.var_to_reduction_descr[keys[0]] .tags_of_type(TestlibTag) == frozenset([BarTag()])) # }}} # {{{ 3. advanced indexing - idx1 = pt.make_placeholder("idx1", (42, 1), "int32") idx1 = idx1.with_tagged_axis(0, FooTag()) @@ -1313,16 +1394,156 @@ def test_unify_axes_tags(): assert (y_unified.axes[0].tags_of_type(TestlibTag) == frozenset([BazTag()])) assert (y_unified.axes[1].tags_of_type(TestlibTag) - == frozenset()) + == frozenset([QuuxTag()])) + # A portion of an axis still has the same units as the whole axis. assert (y_unified.axes[2].tags_of_type(TestlibTag) - == frozenset([FooTag()])) + == frozenset([FooTag(), QuuxTag()])) assert (y_unified.axes[3].tags_of_type(TestlibTag) - == frozenset([BarTag()])) + == frozenset([BarTag(), QuuxTag()])) assert (y_unified.axes[4].tags_of_type(TestlibTag) == frozenset([QuuxTag()])) + + # }}} + + # {{ Reduction Operations with IndexLambda + # {{{ Reduce on outside of scalar expression + + from immutabledict import immutabledict + + import pymbolic.primitives as prim + + def setup_test(): + a = pt.make_placeholder("a", (512)) + b = pt.make_placeholder("b", (512, 10)) + c = pt.make_placeholder("c", (512, 10)) + c = c.with_tagged_axis(1, QuuxTag()) + b = b.with_tagged_axis(1, FooTag()) + a = a.with_tagged_axis(0, BazTag()) + + x = prim.Subscript(prim.Variable("_in0"), (prim.Variable("_0"))) + y = prim.Subscript(prim.Variable("_in1"), + (prim.Variable("_0"), prim.Variable("_r0"))) + z = prim.Subscript(prim.Variable("_in2"), (prim.Variable("_0"), + prim.Variable("_r1"))) + + return a, b, c, x, y, z + + def assert_tags_were_propagated_appropriately(arr): + assert arr.var_to_reduction_descr["_r0"].tags_of_type(TestlibTag) == \ + frozenset([FooTag()]) + assert arr.var_to_reduction_descr["_r1"].tags_of_type(TestlibTag) == \ + frozenset([QuuxTag()]) + + assert arr.axes[0].tags_of_type(TestlibTag) == frozenset([BazTag()]) + for key in ["_in" + str(i) for i in range(2)]: + assert arr.bindings[key].axes[0].tags_of_type(TestlibTag) == \ + frozenset([BazTag()]) + + def get_def_reduction_descrs(): + return immutabledict({"_r0": pt.array.ReductionDescriptor(frozenset([])), + "_r1": pt.array.ReductionDescriptor(frozenset([])) + }) + + a, b, c, x, y, z = setup_test() + # sum((_r0, _r1), a[_0] + b[_0, _r0] + b[_0,_r1])) + w = pt.IndexLambda(expr=pt.scalar_expr.Reduce(prim.Sum((x, y, z)), + pt.reductions.SumReductionOperation, + immutabledict({"_r0": (0, 10), + "_r1": (0, 10)})), + bindings=immutabledict({"_in0": a, "_in1": b, "_in2": c}), + shape=(512,), tags=pt.array._get_default_tags(), + axes=pt.array._get_default_axes(1), + dtype=float, + var_to_reduction_descr=get_def_reduction_descrs()) + + w_unified = pt.unify_axes_tags(w) + + assert_tags_were_propagated_appropriately(w_unified) + + # }}} Reduction on the outside of the scalar expression. + + # {{{ Side-by-Side reduction. + + a, b, c, x, y, z = setup_test() + + # a[_0] + sum(_r0, b[_0, _r0]) + sum(_r1, b[_0,_r1]) + + w = pt.IndexLambda(expr=prim.Sum((x, pt.scalar_expr.Reduce(y, + pt.reductions.SumReductionOperation, + immutabledict({"_r0": (0, 10)})), + pt.scalar_expr.Reduce(z, + pt.reductions.SumReductionOperation, + immutabledict({"_r1": (0, 10)})))), + bindings=immutabledict({"_in0": a, "_in1": b, "_in2": c}), + shape=(512,), tags=pt.array._get_default_tags(), + axes=pt.array._get_default_axes(1), + dtype=float, + var_to_reduction_descr=get_def_reduction_descrs()) + + w_unified = pt.unify_axes_tags(w) + assert_tags_were_propagated_appropriately(w_unified) + + # }}} + + # {{{ Nested Reductions. + # a[_0] + sum(_r0, b[_0, _r0] + sum(_r1, b[_0,_r1])) + a, b, c, x, y, z = setup_test() + + w = pt.IndexLambda(expr=prim.Sum((x, pt.scalar_expr.Reduce(prim.Sum((y, + pt.scalar_expr.Reduce(z, + pt.reductions.SumReductionOperation, + immutabledict({"_r1": (0, 10)})))), + pt.reductions.SumReductionOperation, + immutabledict({"_r0": (0, 10)})))), + bindings=immutabledict({"_in0": a, "_in1": b, "_in2": c}), + shape=(512,), tags=pt.array._get_default_tags(), + axes=pt.array._get_default_axes(1), + dtype=float, + var_to_reduction_descr=get_def_reduction_descrs()) + + w_unified = pt.unify_axes_tags(w) + assert_tags_were_propagated_appropriately(w_unified) # }}} + # }} + +""" +def test_unify_axes_tags_with_unbroadcastable_expressions(): + + a = pt.make_placeholder("a", (512, 10, 8)) + b = pt.make_placeholder("b", (512, 10)) + from testlib import BazTag, FooTag, QuuxTag, TestlibTag + + a = a.with_tagged_axis(0, BazTag()) + a = a.with_tagged_axis(1, QuuxTag()) + a = a.with_tagged_axis(2, FooTag()) + + from immutabledict import immutabledict + import pymbolic.primitives as prim + + x = prim.Subscript(prim.Variable("_in0"), (prim.Variable("_0"), prim.Variable("_1"), + prim.Variable("_2"))) + y = prim.Subscript(prim.Variable("_in1"), + (prim.Variable("_0"), prim.Variable("_1"))) + + z = pt.IndexLambda(expr=x+y, bindings=immutabledict({"_in0": a, "_in1": b}), + shape=(512, 10, 8), tags=pt.array._get_default_tags(), + axes=pt.array._get_default_axes(3), + dtype=float, + var_to_reduction_descr=immutabledict({})) + + z_unified = pt.unify_axes_tags(z) + + assert (z_unified.axes[0].tags_of_type(TestlibTag) == frozenset([BazTag()])) + assert (z_unified.axes[1].tags_of_type(TestlibTag) == frozenset([QuuxTag()])) + assert (z_unified.axes[2].tags_of_type(TestlibTag) == frozenset([FooTag()])) + + for key in z_unified.bindings.keys(): + term = z_unified.bindings[key] + assert (term.axes[0].tags_of_type(TestlibTag) == frozenset([BazTag()])) + assert (term.axes[1].tags_of_type(TestlibTag) == frozenset([QuuxTag()])) +""" def test_ignoring_axes_during_propagation(): from pytools.tag import UniqueTag