From 81846a59f46810af0ed463e5088c47ad2b9ac43e Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sun, 8 Dec 2024 11:54:42 +0100 Subject: [PATCH 01/10] Move subtensor lift rewrites to their own module --- pytensor/tensor/rewriting/__init__.py | 1 + pytensor/tensor/rewriting/subtensor.py | 418 ------------- pytensor/tensor/rewriting/subtensor_lift.py | 448 ++++++++++++++ tests/tensor/rewriting/test_subtensor.py | 531 +---------------- tests/tensor/rewriting/test_subtensor_lift.py | 561 ++++++++++++++++++ 5 files changed, 1015 insertions(+), 944 deletions(-) create mode 100644 pytensor/tensor/rewriting/subtensor_lift.py create mode 100644 tests/tensor/rewriting/test_subtensor_lift.py diff --git a/pytensor/tensor/rewriting/__init__.py b/pytensor/tensor/rewriting/__init__.py index 4e75140ceb..6d411d3827 100644 --- a/pytensor/tensor/rewriting/__init__.py +++ b/pytensor/tensor/rewriting/__init__.py @@ -14,4 +14,5 @@ import pytensor.tensor.rewriting.shape import pytensor.tensor.rewriting.special import pytensor.tensor.rewriting.subtensor +import pytensor.tensor.rewriting.subtensor_lift import pytensor.tensor.rewriting.uncanonicalize diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 4b824e46cf..ca27761319 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -1,6 +1,5 @@ import itertools import sys -from collections.abc import Iterable import numpy as np @@ -19,11 +18,9 @@ from pytensor.tensor.basic import ( Alloc, Join, - MakeVector, ScalarFromTensor, TensorFromScalar, alloc, - as_tensor, cast, concatenate, get_scalar_constant_value, @@ -35,11 +32,8 @@ from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import ( - Dot, add, and_, - ceil_intdiv, - dot, eq, ge, gt, @@ -57,13 +51,8 @@ register_stabilize, ) from pytensor.tensor.shape import ( - Shape, - SpecifyShape, - Unbroadcast, shape_padleft, shape_tuple, - specify_shape, - unbroadcast, ) from pytensor.tensor.sharedvar import TensorSharedVariable from pytensor.tensor.subtensor import ( @@ -77,7 +66,6 @@ advanced_subtensor, advanced_subtensor1, as_index_constant, - as_index_literal, get_canonical_form_slice, get_constant_idx, get_idx_list, @@ -276,64 +264,6 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node): return [new_res] -@register_canonicalize -@register_stabilize -@register_specialize -@node_rewriter([Subtensor]) -def local_subtensor_of_dot(fgraph, node): - """Rewrite ``at.dot(A, B)[idxs]`` into ``at.dot(A[idxs_a], B[idxs_b])``. - ``idxs_a`` is the first ``A.ndim-1`` entries of ``idxs``, and ``idxs_b`` is - the remaining entries of ``idxs`` (if any), modified to skip the - second-to-last dimension of ``B`` (because dot sums over this dimension). - """ - if not isinstance(node.op, Subtensor): - return - if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)): - return - # If there is other node that use the outputs of the dot - # We don't want to compute twice the sub part. - if len(fgraph.clients[node.inputs[0]]) > 1: - return - - a = node.inputs[0].owner.inputs[0] - b = node.inputs[0].owner.inputs[1] - - idx_list = get_idx_list(node.inputs, node.op.idx_list) - - num_a_indices = min(a.ndim - 1, len(idx_list)) - a_indices = idx_list[:num_a_indices] - b_indices = idx_list[num_a_indices:] - - # This is necessary because np.dot sums the last index of a with the second to last of b - # so we want to skip the second-to-last index into b. - # This wasn't necessary for a, because we just omitted the last index. - # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:] - # (dot also handles b.ndim < 2 as a special case) - if b.ndim > 1 and len(b_indices) >= b.ndim - 1: - b_indices = ( - b_indices[: b.ndim - 2] - + (slice(None, None, None),) - + b_indices[b.ndim - 2 :] - ) - - a_sub = a.__getitem__(tuple(a_indices)) - b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b - - # Copy over previous output stacktrace to a_sub and b_sub, - # because an error in the subtensor operation (e.g. an index error) - # on either a or b must correspond to an error in the - # subtensor operation on their dot product. - copy_stack_trace(node.outputs[0], [a_sub, b_sub]) - - # Copy over previous output stacktrace and previous dot product stacktrace, - # because an error here may correspond to an either in either the original - # dot product, or in the dot product after the subtensor operation. - r = dot(a_sub, b_sub) - copy_stack_trace([node.outputs[0], node.inputs[0]], r) - - return [r] - - @register_infer_shape @register_useless @register_canonicalize @@ -419,110 +349,6 @@ def local_useless_slice(fgraph, node): return [out] -# fast_compile to allow opt subtensor(cast{float32}(make_vector)) -@register_canonicalize("fast_compile") -@node_rewriter([Subtensor]) -def local_subtensor_lift(fgraph, node): - """ - unary(x)[idx] -> unary(x[idx])#any broadcast pattern. - - Handles the following unary ops: - elemwise(x,...)[idx] -> elemwise(x[idx],...) - when x,... are broadcasted scalar or not broadcasted at all - Unbroadcast(x)[idx] => Unbroadcast(x[idx]) - - """ - if isinstance(node.op, Subtensor): - u = node.inputs[0] - if u.owner is None or len(fgraph.clients[u]) > 1: - return False - - if isinstance(u.owner.op, Elemwise) and len(u.owner.inputs) == 1: - idx = node.inputs[1:] - x_idx = node.op(u.owner.inputs[0], *idx) - # Copy over previous output stacktrace - copy_stack_trace(node.outputs, x_idx) - ret = u.owner.op(x_idx) - # Copy over previous output stacktrace - # and stacktrace from previous unary operation - copy_stack_trace([node.outputs[0], node.inputs[0]], ret) - return [ret] - - if isinstance(u.owner.op, Elemwise): - new_inputs = [] - if all(sum(i.type.broadcastable) == 0 for i in u.owner.inputs): - # There is no broadcastable in the inputs - idx = node.inputs[1:] - new_inputs = [node.op(i, *idx) for i in u.owner.inputs] - # Copy over previous output stacktrace - copy_stack_trace(node.outputs[0], new_inputs) - - ret = u.owner.op(*new_inputs) - # Copy over previous output stacktrace - # and stacktrace from previous unary operation - copy_stack_trace([node.outputs[0], node.inputs[0]], ret) - return [ret] - elif all(sum(i.type.broadcastable) in [i.ndim, 0] for i in u.owner.inputs): - # There is no broadcastable in the inputs or it is scalar - idx = node.inputs[1:] - new_inputs = [] - for i in u.owner.inputs: - if sum(i.type.broadcastable) == 0: - new_inputs.append(node.op(i, *idx)) - else: - # If the subtensor remove some dims, we must - # lower the number of dimensions of this scalar. - if node.outputs[0].ndim == i.ndim: - new_inputs.append(i) - else: - new_inputs.append( - i.dimshuffle(["x"] * node.outputs[0].ndim) - ) - - # Copy over previous output stacktrace - copy_stack_trace(node.outputs[0], new_inputs) - - ret = u.owner.op(*new_inputs) - # Copy over previous output stacktrace - # and stacktrace from previous unary operation - copy_stack_trace([node.outputs[0], node.inputs[0]], ret) - return [ret] - - if isinstance(u.owner.op, Unbroadcast): - # Subtensor might reduce dim., adapt broadcast pattern accordingly - old_axes = u.owner.op.axes - new_axes = [] - - # loop through indices being subtensor-ed - # i indexes broadcastable pattern before subtensor - # j indexes broadcastable pattern after subtensor - j = 0 - for i, x in enumerate(node.op.idx_list): - # if it is not a slice, it will reduce the dimension, should - # not appear in the broascastable dimensions - if isinstance(x, slice): - if i in old_axes: - new_axes.append(j) - j += 1 - # now keep the broadcastable pattern of all - # items not appearing in subtensor list - for i in range(len(node.op.idx_list), len(u.broadcastable)): - if i in old_axes: - new_axes.append(j) - j += 1 - - subt_x = node.op(u.owner.inputs[0], *node.inputs[1:]) - # Copy over previous output stacktrace - copy_stack_trace(node.outputs[0], subt_x) - - rbcast_subt_x = unbroadcast(subt_x, *new_axes) - # Copy over previous output stacktrace - # and stacktrace from previous unary operation - copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x) - - return [rbcast_subt_x] - - @register_canonicalize @register_specialize @node_rewriter([Subtensor]) @@ -653,76 +479,6 @@ def local_subtensor_remove_broadcastable_index(fgraph, node): return [node.inputs[0].dimshuffle(tuple(remain_dim))] -@register_infer_shape -@register_useless -@register_canonicalize -@register_specialize -@node_rewriter([Subtensor]) -def local_subtensor_of_alloc(fgraph, node): - """ - - alloc(val)[x:y] -> alloc(val[...]) - alloc(val)[x:y] -> alloc(val) - This can be seen as a lift, but it also reduce the number of computation/memory. - - """ - if not isinstance(node.op, Subtensor): - return False - u = node.inputs[0] - if u.owner is None: - return False - if not isinstance(u.owner.op, Alloc): - return False - slices = get_idx_list(node.inputs, node.op.idx_list) - val = u.owner.inputs[0] - dims = u.owner.inputs[1:] - assert len(slices) <= len(dims) - - # Number of dimensions added to val - n_added_dims = u.ndim - val.ndim - # Dimensions of the returned alloc - nw_dims = [] - # Slices to take from val - val_slices = [] - - for i, (sl, dim) in enumerate(zip(slices, dims, strict=False)): - # If val was not copied over that dim, - # we need to take the appropriate subtensor on it. - if i >= n_added_dims: - # We check that the corresponding val dimensions was - # not a broadcasted dimensions. - if ( - val.type.ndim > (i - n_added_dims) - and val.type.broadcastable[i - n_added_dims] - ): - val_slices.append(slice(None)) - else: - val_slices.append(sl) - - csl, _ = get_canonical_form_slice(sl, dim) - if type(csl) is not slice: - # That dimension is removed. - pass - else: - nw_dim = csl.stop - csl.start - - if csl.step != 1: - # Do not add the ceil_intdiv() graphs in the graphs - # when this is not needed as it prevent detecting the - # correct broadcast pattern. - nw_dim = ceil_intdiv(nw_dim, csl.step) - nw_dims += [nw_dim] - - nw_val = val[tuple(val_slices)] - nw_dims += dims[len(slices) :] - if nw_val.ndim > len(nw_dims): - return False - rval = alloc(nw_val, *nw_dims) - if not isinstance(rval, list | tuple): - rval = [rval] - return rval - - @register_specialize @register_canonicalize @node_rewriter([Subtensor]) @@ -762,91 +518,6 @@ def local_subtensor_inc_subtensor(fgraph, node): return -@register_infer_shape -@register_specialize -@register_canonicalize("fast_compile") -@register_useless -@node_rewriter([Subtensor, AdvancedSubtensor1]) -def local_subtensor_make_vector(fgraph, node): - """Perform ``*Subtensor*`` operations on ``MakeVector`` outputs when the indices are constant. - - Replace all ``Subtensor`` and ``MakeVector`` cases like: - [a,b,c][0] -> a - [a,b,c][0:2] -> [a,b] - - Replace all ``AdvancedSubtensor1`` and ``MakeVector`` cases like: - [a,b,c][[0,2]] -> [a,c] - - We can do this for constant indexes. - - .. note: - - This optimization implicitly relies on shape optimizations. - - TODO: This only applies to a single indexed dimension; we should have - something more general for constant ``*Subtensor*`` graphs (or perhaps - include this kind of work in the constant folding). - """ - - if not isinstance(node.op, Subtensor | AdvancedSubtensor1): - return False - - x = node.inputs[0] - - if not (x.owner and isinstance(x.owner.op, MakeVector)): - return False - - make_vector_op = x.owner.op - - if isinstance(node.op, Subtensor): - idxs = node.op.idx_list - - # Subtensor has no indexes, return make_vector - if not idxs: - return [x] - - (idx,) = idxs - - if isinstance(idx, ps.ScalarType | TensorType): - old_idx, idx = idx, node.inputs[1] - assert idx.type.is_super(old_idx) - elif isinstance(node.op, AdvancedSubtensor1): - idx = node.inputs[1] - - if isinstance(idx, int | np.integer): - return [x.owner.inputs[idx]] - elif isinstance(idx, Variable): - if idx.ndim == 0: - try: - v = get_underlying_scalar_constant_value( - idx, only_process_constants=True - ) - try: - ret = [x.owner.inputs[v]] - except IndexError: - raise NotScalarConstantError("Bad user graph!") - return ret - except NotScalarConstantError: - pass - elif idx.ndim == 1 and isinstance(idx, Constant): - values = list(map(int, list(idx.value))) - ret = make_vector_op(*[x.owner.inputs[v] for v in values]) - copy_stack_trace(node.outputs[0], ret) - return [ret] - elif isinstance(idx, slice): - # The index is a slice. If it's a constant slice, we can perform the - # index operation here. - try: - const_slice = get_constant_idx( - node.op.idx_list, node.inputs, allow_partial=False - )[0] - ret = make_vector_op(*x.owner.inputs[const_slice]) - copy_stack_trace(node.outputs, ret) - return [ret] - except NotScalarConstantError: - pass - - @register_infer_shape @register_useless @register_canonicalize @@ -1635,95 +1306,6 @@ def local_useless_inc_subtensor_alloc(fgraph, node): return [r] -@register_specialize -@register_canonicalize -@node_rewriter([Subtensor]) -def local_subtensor_shape_constant(fgraph, node): - r"""Simplify constant `Subtensor`\s on `Shape`\s dimensions that are known. - - We want to convert graphs like - - Subtensor{int64} [id A] '' - |Shape [id B] '' - | | [id C] - |ScalarConstant{0} [id D] - - into - - TensorConstant{1} - - TODO: Something like `local_shape_to_shape_i` should be a general - canonicalization, and not a `ShapeFeature`-dependent rewrite. If that were - the case, we could change this to only operate on `Shape_i`\s. - Currently, we're not handling them because they should only appear when - `ShapeFeature` is present, and it will also simplify/remove them. - - """ - if not isinstance(node.op, Subtensor): - return False - - shape = node.inputs[0] - - if not (shape.owner and isinstance(shape.owner.op, Shape)): - return False - - shape_arg = shape.owner.inputs[0] - - (idx,) = get_idx_list(node.inputs, node.op.idx_list) - - try: - idx_val = as_index_literal(idx) - except NotScalarConstantError: - return False - - assert idx_val != np.newaxis - - if not isinstance(shape_arg.type, TensorType): - return False - - shape_parts = shape_arg.type.broadcastable[idx_val] - - if isinstance(shape_parts, Iterable): - if all(shape_parts): - return [as_tensor([1] * len(shape_parts), dtype=np.int64, ndim=1)] - elif shape_parts: - return [as_tensor(1, dtype=np.int64)] - - -@register_canonicalize -@node_rewriter([Subtensor]) -def local_subtensor_SpecifyShape_lift(fgraph, node): - """Lift ``specify_shape(x, s)[i_1, ..., i_n]`` to ``specify_shape(x[i1, ... , i_n], s[n:])``.""" - - if not isinstance(node.op, Subtensor): - return False - - specify_shape_node = node.inputs[0] - - if not ( - specify_shape_node.owner - and isinstance(specify_shape_node.owner.op, SpecifyShape) - ): - return False - - obj_arg = specify_shape_node.owner.inputs[0] - shape_arg = specify_shape_node.owner.inputs[1:] - - indices = get_idx_list(node.inputs, node.op.idx_list) - - if any( - isinstance(index, slice) or isinstance(getattr(index, "type", None), SliceType) - for index in indices - ): - return False - - new_obj_arg = obj_arg[indices] - # No need to specify shape for scalar outputs - if new_obj_arg.ndim == 0: - return [new_obj_arg] - return [specify_shape(new_obj_arg, shape_arg[len(indices) :])] - - @register_specialize @node_rewriter([Join]) def local_join_subtensors(fgraph, node): diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py new file mode 100644 index 0000000000..839d43f53f --- /dev/null +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -0,0 +1,448 @@ +from collections.abc import Iterable + +import numpy as np + +from pytensor import Variable +from pytensor.graph import Constant, node_rewriter +from pytensor.graph.rewriting.basic import copy_stack_trace +from pytensor.scalar import basic as ps +from pytensor.tensor.basic import ( + Alloc, + MakeVector, + alloc, + as_tensor, + get_underlying_scalar_constant_value, + register_infer_shape, +) +from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.exceptions import NotScalarConstantError +from pytensor.tensor.math import Dot, ceil_intdiv, dot +from pytensor.tensor.rewriting.basic import ( + register_canonicalize, + register_specialize, + register_stabilize, +) +from pytensor.tensor.rewriting.subtensor import register_useless +from pytensor.tensor.shape import ( + Shape, + SpecifyShape, + Unbroadcast, + specify_shape, + unbroadcast, +) +from pytensor.tensor.subtensor import ( + AdvancedSubtensor1, + Subtensor, + as_index_literal, + get_canonical_form_slice, + get_constant_idx, + get_idx_list, +) +from pytensor.tensor.type import TensorType +from pytensor.tensor.type_other import SliceType + + +@register_canonicalize +@register_stabilize +@register_specialize +@node_rewriter([Subtensor]) +def local_subtensor_of_dot(fgraph, node): + """Rewrite ``at.dot(A, B)[idxs]`` into ``at.dot(A[idxs_a], B[idxs_b])``. + ``idxs_a`` is the first ``A.ndim-1`` entries of ``idxs``, and ``idxs_b`` is + the remaining entries of ``idxs`` (if any), modified to skip the + second-to-last dimension of ``B`` (because dot sums over this dimension). + """ + if not isinstance(node.op, Subtensor): + return + if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)): + return + # If there is other node that use the outputs of the dot + # We don't want to compute twice the sub part. + if len(fgraph.clients[node.inputs[0]]) > 1: + return + + a = node.inputs[0].owner.inputs[0] + b = node.inputs[0].owner.inputs[1] + + idx_list = get_idx_list(node.inputs, node.op.idx_list) + + num_a_indices = min(a.ndim - 1, len(idx_list)) + a_indices = idx_list[:num_a_indices] + b_indices = idx_list[num_a_indices:] + + # This is necessary because np.dot sums the last index of a with the second to last of b + # so we want to skip the second-to-last index into b. + # This wasn't necessary for a, because we just omitted the last index. + # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:] + # (dot also handles b.ndim < 2 as a special case) + if b.ndim > 1 and len(b_indices) >= b.ndim - 1: + b_indices = ( + b_indices[: b.ndim - 2] + + (slice(None, None, None),) + + b_indices[b.ndim - 2 :] + ) + + a_sub = a.__getitem__(tuple(a_indices)) + b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b + + # Copy over previous output stacktrace to a_sub and b_sub, + # because an error in the subtensor operation (e.g. an index error) + # on either a or b must correspond to an error in the + # subtensor operation on their dot product. + copy_stack_trace(node.outputs[0], [a_sub, b_sub]) + + # Copy over previous output stacktrace and previous dot product stacktrace, + # because an error here may correspond to an either in either the original + # dot product, or in the dot product after the subtensor operation. + r = dot(a_sub, b_sub) + copy_stack_trace([node.outputs[0], node.inputs[0]], r) + + return [r] + + +# fast_compile to allow opt subtensor(cast{float32}(make_vector)) +@register_canonicalize("fast_compile") +@node_rewriter([Subtensor]) +def local_subtensor_lift(fgraph, node): + """ + unary(x)[idx] -> unary(x[idx])#any broadcast pattern. + + Handles the following unary ops: + elemwise(x,...)[idx] -> elemwise(x[idx],...) + when x,... are broadcasted scalar or not broadcasted at all + Unbroadcast(x)[idx] => Unbroadcast(x[idx]) + + """ + if isinstance(node.op, Subtensor): + u = node.inputs[0] + if u.owner is None or len(fgraph.clients[u]) > 1: + return False + + if isinstance(u.owner.op, Elemwise) and len(u.owner.inputs) == 1: + idx = node.inputs[1:] + x_idx = node.op(u.owner.inputs[0], *idx) + # Copy over previous output stacktrace + copy_stack_trace(node.outputs, x_idx) + ret = u.owner.op(x_idx) + # Copy over previous output stacktrace + # and stacktrace from previous unary operation + copy_stack_trace([node.outputs[0], node.inputs[0]], ret) + return [ret] + + if isinstance(u.owner.op, Elemwise): + new_inputs = [] + if all(sum(i.type.broadcastable) == 0 for i in u.owner.inputs): + # There is no broadcastable in the inputs + idx = node.inputs[1:] + new_inputs = [node.op(i, *idx) for i in u.owner.inputs] + # Copy over previous output stacktrace + copy_stack_trace(node.outputs[0], new_inputs) + + ret = u.owner.op(*new_inputs) + # Copy over previous output stacktrace + # and stacktrace from previous unary operation + copy_stack_trace([node.outputs[0], node.inputs[0]], ret) + return [ret] + elif all(sum(i.type.broadcastable) in [i.ndim, 0] for i in u.owner.inputs): + # There is no broadcastable in the inputs or it is scalar + idx = node.inputs[1:] + new_inputs = [] + for i in u.owner.inputs: + if sum(i.type.broadcastable) == 0: + new_inputs.append(node.op(i, *idx)) + else: + # If the subtensor remove some dims, we must + # lower the number of dimensions of this scalar. + if node.outputs[0].ndim == i.ndim: + new_inputs.append(i) + else: + new_inputs.append( + i.dimshuffle(["x"] * node.outputs[0].ndim) + ) + + # Copy over previous output stacktrace + copy_stack_trace(node.outputs[0], new_inputs) + + ret = u.owner.op(*new_inputs) + # Copy over previous output stacktrace + # and stacktrace from previous unary operation + copy_stack_trace([node.outputs[0], node.inputs[0]], ret) + return [ret] + + if isinstance(u.owner.op, Unbroadcast): + # Subtensor might reduce dim., adapt broadcast pattern accordingly + old_axes = u.owner.op.axes + new_axes = [] + + # loop through indices being subtensor-ed + # i indexes broadcastable pattern before subtensor + # j indexes broadcastable pattern after subtensor + j = 0 + for i, x in enumerate(node.op.idx_list): + # if it is not a slice, it will reduce the dimension, should + # not appear in the broascastable dimensions + if isinstance(x, slice): + if i in old_axes: + new_axes.append(j) + j += 1 + # now keep the broadcastable pattern of all + # items not appearing in subtensor list + for i in range(len(node.op.idx_list), len(u.broadcastable)): + if i in old_axes: + new_axes.append(j) + j += 1 + + subt_x = node.op(u.owner.inputs[0], *node.inputs[1:]) + # Copy over previous output stacktrace + copy_stack_trace(node.outputs[0], subt_x) + + rbcast_subt_x = unbroadcast(subt_x, *new_axes) + # Copy over previous output stacktrace + # and stacktrace from previous unary operation + copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x) + + return [rbcast_subt_x] + + +@register_infer_shape +@register_useless +@register_canonicalize +@register_specialize +@node_rewriter([Subtensor]) +def local_subtensor_of_alloc(fgraph, node): + """ + + alloc(val)[x:y] -> alloc(val[...]) + alloc(val)[x:y] -> alloc(val) + This can be seen as a lift, but it also reduce the number of computation/memory. + + """ + if not isinstance(node.op, Subtensor): + return False + u = node.inputs[0] + if u.owner is None: + return False + if not isinstance(u.owner.op, Alloc): + return False + slices = get_idx_list(node.inputs, node.op.idx_list) + val = u.owner.inputs[0] + dims = u.owner.inputs[1:] + assert len(slices) <= len(dims) + + # Number of dimensions added to val + n_added_dims = u.ndim - val.ndim + # Dimensions of the returned alloc + nw_dims = [] + # Slices to take from val + val_slices = [] + + for i, (sl, dim) in enumerate(zip(slices, dims, strict=False)): + # If val was not copied over that dim, + # we need to take the appropriate subtensor on it. + if i >= n_added_dims: + # We check that the corresponding val dimensions was + # not a broadcasted dimensions. + if ( + val.type.ndim > (i - n_added_dims) + and val.type.broadcastable[i - n_added_dims] + ): + val_slices.append(slice(None)) + else: + val_slices.append(sl) + + csl, _ = get_canonical_form_slice(sl, dim) + if type(csl) is not slice: + # That dimension is removed. + pass + else: + nw_dim = csl.stop - csl.start + + if csl.step != 1: + # Do not add the ceil_intdiv() graphs in the graphs + # when this is not needed as it prevent detecting the + # correct broadcast pattern. + nw_dim = ceil_intdiv(nw_dim, csl.step) + nw_dims += [nw_dim] + + nw_val = val[tuple(val_slices)] + nw_dims += dims[len(slices) :] + if nw_val.ndim > len(nw_dims): + return False + rval = alloc(nw_val, *nw_dims) + if not isinstance(rval, list | tuple): + rval = [rval] + return rval + + +@register_canonicalize +@node_rewriter([Subtensor]) +def local_subtensor_SpecifyShape_lift(fgraph, node): + """Lift ``specify_shape(x, s)[i_1, ..., i_n]`` to ``specify_shape(x[i1, ... , i_n], s[n:])``.""" + + if not isinstance(node.op, Subtensor): + return False + + specify_shape_node = node.inputs[0] + + if not ( + specify_shape_node.owner + and isinstance(specify_shape_node.owner.op, SpecifyShape) + ): + return False + + obj_arg = specify_shape_node.owner.inputs[0] + shape_arg = specify_shape_node.owner.inputs[1:] + + indices = get_idx_list(node.inputs, node.op.idx_list) + + if any( + isinstance(index, slice) or isinstance(getattr(index, "type", None), SliceType) + for index in indices + ): + return False + + new_obj_arg = obj_arg[indices] + # No need to specify shape for scalar outputs + if new_obj_arg.ndim == 0: + return [new_obj_arg] + return [specify_shape(new_obj_arg, shape_arg[len(indices) :])] + + +@register_infer_shape +@register_specialize +@register_canonicalize("fast_compile") +@register_useless +@node_rewriter([Subtensor, AdvancedSubtensor1]) +def local_subtensor_make_vector(fgraph, node): + """Perform ``*Subtensor*`` operations on ``MakeVector`` outputs when the indices are constant. + + Replace all ``Subtensor`` and ``MakeVector`` cases like: + [a,b,c][0] -> a + [a,b,c][0:2] -> [a,b] + + Replace all ``AdvancedSubtensor1`` and ``MakeVector`` cases like: + [a,b,c][[0,2]] -> [a,c] + + We can do this for constant indexes. + + .. note: + + This optimization implicitly relies on shape optimizations. + + TODO: This only applies to a single indexed dimension; we should have + something more general for constant ``*Subtensor*`` graphs (or perhaps + include this kind of work in the constant folding). + """ + + if not isinstance(node.op, Subtensor | AdvancedSubtensor1): + return False + + x = node.inputs[0] + + if not (x.owner and isinstance(x.owner.op, MakeVector)): + return False + + make_vector_op = x.owner.op + + if isinstance(node.op, Subtensor): + idxs = node.op.idx_list + + # Subtensor has no indexes, return make_vector + if not idxs: + return [x] + + (idx,) = idxs + + if isinstance(idx, ps.ScalarType | TensorType): + old_idx, idx = idx, node.inputs[1] + assert idx.type.is_super(old_idx) + elif isinstance(node.op, AdvancedSubtensor1): + idx = node.inputs[1] + + if isinstance(idx, int | np.integer): + return [x.owner.inputs[idx]] + elif isinstance(idx, Variable): + if idx.ndim == 0: + try: + v = get_underlying_scalar_constant_value( + idx, only_process_constants=True + ) + try: + ret = [x.owner.inputs[v]] + except IndexError: + raise NotScalarConstantError("Bad user graph!") + return ret + except NotScalarConstantError: + pass + elif idx.ndim == 1 and isinstance(idx, Constant): + values = list(map(int, list(idx.value))) + ret = make_vector_op(*[x.owner.inputs[v] for v in values]) + copy_stack_trace(node.outputs[0], ret) + return [ret] + elif isinstance(idx, slice): + # The index is a slice. If it's a constant slice, we can perform the + # index operation here. + try: + const_slice = get_constant_idx( + node.op.idx_list, node.inputs, allow_partial=False + )[0] + ret = make_vector_op(*x.owner.inputs[const_slice]) + copy_stack_trace(node.outputs, ret) + return [ret] + except NotScalarConstantError: + pass + + +@register_specialize +@register_canonicalize +@node_rewriter([Subtensor]) +def local_subtensor_shape_constant(fgraph, node): + r"""Simplify constant `Subtensor`\s on `Shape`\s dimensions that are known. + + We want to convert graphs like + + Subtensor{int64} [id A] '' + |Shape [id B] '' + | | [id C] + |ScalarConstant{0} [id D] + + into + + TensorConstant{1} + + TODO: Something like `local_shape_to_shape_i` should be a general + canonicalization, and not a `ShapeFeature`-dependent rewrite. If that were + the case, we could change this to only operate on `Shape_i`\s. + Currently, we're not handling them because they should only appear when + `ShapeFeature` is present, and it will also simplify/remove them. + + """ + if not isinstance(node.op, Subtensor): + return False + + shape = node.inputs[0] + + if not (shape.owner and isinstance(shape.owner.op, Shape)): + return False + + shape_arg = shape.owner.inputs[0] + + (idx,) = get_idx_list(node.inputs, node.op.idx_list) + + try: + idx_val = as_index_literal(idx) + except NotScalarConstantError: + return False + + assert idx_val != np.newaxis + + if not isinstance(shape_arg.type, TensorType): + return False + + shape_parts = shape_arg.type.broadcastable[idx_val] + + if isinstance(shape_parts, Iterable): + if all(shape_parts): + return [as_tensor([1] * len(shape_parts), dtype=np.int64, ndim=1)] + elif shape_parts: + return [as_tensor(1, dtype=np.int64)] diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index fcfd72ddf2..c7c05e5291 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -9,28 +9,19 @@ from pytensor.compile.mode import Mode, get_default_mode, get_mode from pytensor.compile.ops import DeepCopyOp from pytensor.configdefaults import config -from pytensor.graph import FunctionGraph, vectorize_graph +from pytensor.graph import vectorize_graph from pytensor.graph.basic import Constant, Variable, ancestors, equal_computations from pytensor.graph.rewriting.basic import check_stack_trace -from pytensor.graph.rewriting.db import RewriteDatabaseQuery -from pytensor.graph.rewriting.utils import rewrite_graph -from pytensor.graph.type import Type from pytensor.raise_op import Assert -from pytensor.tensor import inplace -from pytensor.tensor.basic import Alloc, MakeVector, _convert_to_int8, make_vector +from pytensor.tensor.basic import Alloc, _convert_to_int8 from pytensor.tensor.blockwise import Blockwise -from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.math import Dot, add, dot, exp, sqr +from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.math import Dot, dot, exp, sqr from pytensor.tensor.rewriting.subtensor import ( local_replace_AdvancedSubtensor, - local_subtensor_make_vector, - local_subtensor_shape_constant, ) from pytensor.tensor.shape import ( SpecifyShape, - Unbroadcast, - _shape, - shape, specify_shape, ) from pytensor.tensor.subtensor import ( @@ -50,19 +41,15 @@ dmatrix, fmatrix, iscalar, - iscalars, ivector, - lscalar, - lscalars, matrix, - row, scalar, tensor, tensor3, tensor4, vector, ) -from pytensor.tensor.type_other import make_slice, slicetype +from pytensor.tensor.type_other import make_slice from tests import unittest_tools as utt from tests.unittest_tools import create_pytensor_param @@ -666,320 +653,6 @@ def test_different_dtypes(self): assert np.array_equal(f(x_, i_, v_), v_.astype("int8")) -class TestLocalSubtensorMakeVector: - mode = get_mode("FAST_RUN").including("local_subtensor_make_vector") - - def test_scalar_idx(self): - x, y, z = lscalars("xyz") - v = make_vector(x, y, z) - f = function([x, y, z], v[0], mode=self.mode) - - prog = f.maker.fgraph.toposort() - assert len(prog) == 1 - assert isinstance(prog[0].op, DeepCopyOp) - assert f(0, 1, 2) == 0 - - def test_idx_symbolic(self): - x, y, z = iscalars("xyz") - v = MakeVector("int32")(x, y, z) - idx = pt.as_tensor([0], dtype=np.int64) - f = function([x, y, z], v[idx], mode=self.mode) - - opt_fgraph = f.maker.fgraph - assert opt_fgraph.outputs[0].dtype == "int32" - assert isinstance(opt_fgraph.outputs[0].owner.op, MakeVector) - assert f(0, 1, 2) == np.array([0], dtype=np.int32) - - def test_slice_idx_start(self): - x, y, z = iscalars("xyz") - v = MakeVector("int32")(x, y, z) - f = function([x, y, z], v[1:], mode=self.mode, on_unused_input="ignore") - - opt_fgraph = f.maker.fgraph - assert opt_fgraph.outputs[0].dtype == "int32" - assert isinstance(opt_fgraph.outputs[0].owner.op, MakeVector) - assert len(opt_fgraph.outputs[0].owner.inputs) == 2 - r = f(0, 1, 2) - assert r[0] == 1 and r[1] == 2 - - def test_slice_idx_stop(self): - x, y, z = lscalars("xyz") - v = make_vector(x, y, z) - f = function([x, y, z], v[:2], mode=self.mode) - - prog = f.maker.fgraph.toposort() - assert len(prog) == 1 - assert isinstance(prog[0].op, MakeVector) - assert len(prog[0].inputs) == 2 - r = f(0, 1, 2) - assert r[0] == 0 and r[1] == 1 - - def test_slice_idx_step(self): - x, y, z = lscalars("xyz") - v = make_vector(x, y, z) - f = function([x, y, z], v[::2], mode=self.mode) - - prog = f.maker.fgraph.toposort() - assert len(prog) == 1 - assert isinstance(prog[0].op, MakeVector) - assert len(prog[0].inputs) == 2 - r = f(0, 1, 2) - assert r[0] == 0 and r[1] == 2 - - def test_AdvancedSubtensor1_idx(self): - x, y, z = lscalars("xyz") - v = make_vector(x, y, z) - f = function([x, y, z], v[[0, 2]], mode=self.mode) - - prog = f.maker.fgraph.toposort() - assert len(prog) == 1 - assert isinstance(prog[0].op, MakeVector) - assert len(prog[0].inputs) == 2 - r = f(0, 1, 2) - assert r[0] == 0 and r[1] == 2 - - def test_MakeVector_idx(self): - x, y, z, q = lscalars("xyzq") - v = make_vector(x, y, z) - q = make_vector(0, 2) - f = function([x, y, z], v[q], mode=self.mode) - - prog = f.maker.fgraph.toposort() - assert len(prog) == 1 - assert isinstance(prog[0].op, MakeVector) - assert len(prog[0].inputs) == 2 - r = f(0, 1, 2) - assert r[0] == 0 and r[1] == 2 - - def test_stack_trace(self): - x, y, z = lscalars("xyz") - v = make_vector(x, y, z) - - mode = get_default_mode().including("local_subtensor_make_vector") - - # list of subtensor cases, where local_subtensor_make_vector - # inserts a new MakeVector node - v_subtensors = [v[:2], v[::2], v[[0, 2]]] - - for v_subtensor in v_subtensors: - f = function([x, y, z], v_subtensor, mode=mode) - assert check_stack_trace(f, ops_to_check="all") - - def test_empty_subtensor(self): - x, y = lscalars("xy") - v = make_vector(x, y) - out = v[()] - - fgraph = FunctionGraph(outputs=[out], clone=False) - node = fgraph.outputs[0].owner - assert isinstance(node.op, Subtensor) - - assert local_subtensor_make_vector.transform(fgraph, node) == [v] - - -class TestLocalSubtensorLift: - def test_basic(self): - # basic test that the Op works - x = matrix("x") - f = function([x], exp(x)[0], mode=mode_opt) - - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f, ops_to_check="all") - - prog = f.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) # first subtensor - assert prog[1].op == exp - assert len(prog) == 2 - f([[0, 1], [2, 3]]) # let debugmode test something - - def test_basic_1(self): - # as test0, but we reuse the output of the elemwise - # So we should not lift the subtensor - x = matrix("x") - f = function([x], [exp(x)[0], exp(x)], mode=mode_opt) - - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f, ops_to_check=[Subtensor, Elemwise]) - - prog = f.maker.fgraph.toposort() - assert prog[0].op == exp - assert isinstance(prog[1].op, Subtensor) # first subtensor - assert isinstance(prog[2].op, DeepCopyOp) - assert len(prog) == 3 - f([[0, 1], [2, 3]]) # let debugmode test something - - def test_basic_2(self): - # basic test that the optimization work with scalar broadcasted - x = matrix("x") - y = scalar("y") - z = matrix("z") - f = function([x, y, z], exp(x + y + z)[0], mode=mode_opt) - - prog = f.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, DimShuffle) - assert isinstance(prog[2].op, Subtensor) - assert isinstance(prog[3].op.scalar_op, ps.Composite) # Composite{add,add} - assert len(prog) == 4 - - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f, ops_to_check=[Subtensor]) - - # let debugmode test something - f([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]]) - - def test_basic_3(self): - # as 1, but take a slice - x = matrix("x") - y = scalar("y") - z = matrix("z") - f = function([x, y, z], exp(x + y + z)[0:2], mode=mode_opt) - - prog = f.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, DimShuffle) - assert isinstance(prog[2].op, Subtensor) - assert isinstance(prog[3].op.scalar_op, ps.Composite) # Composite{add,add} - assert len(prog) == 4 - - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f, ops_to_check=[Subtensor]) - - # let debugmode test something - f([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]]) - - def test_basic_4(self): - # basic test that the optimization does work with broadcasting - # for unary elemwise. - y = vector("y") - f = function([y], exp(y.dimshuffle(0, "x"))[0], mode=mode_opt) - - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f, ops_to_check="all") - - prog = f.maker.fgraph.toposort() - assert isinstance(prog[0].op, DimShuffle) - assert isinstance(prog[1].op, Subtensor) - assert prog[2].op == exp - assert len(prog) == 3 - f([4, 5]) # let debugmode test something - - @utt.assertFailure_fast - def test_basic_5(self): - # basic test that the optimization doesn't work with broadcasting - # ... It *could* be extended to, - # ... but right now it doesn't, so it shouldn't try. - x = matrix("x") - y = vector("y") - f = function([x, y], exp(x + y)[0], mode=mode_opt) - - # Opt doesn't apply, so no need for check_stack_trace - # assert check_stack_trace(f, ops_to_check='all') - - prog = f.maker.fgraph.toposort() - assert isinstance(prog[0].op, DimShuffle) - assert prog[1].op == add - assert isinstance(prog[2].op, Subtensor) # first subtensor - assert prog[3].op == inplace.exp_inplace - assert len(prog) == 4 - f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something - - def test_basic_6(self): - # test that we don't lift when we reuse the output of the - # elemwise for other computation. - x = matrix("x") - y = vector("y") - f = function([x, y], [exp(x + y)[0], exp(x + y) + x], mode=mode_opt) - - # Opt doesn't apply, so no need for check_stack_trace - # assert check_stack_trace(f, ops_to_check=Subtensor) - - prog = f.maker.fgraph.toposort() - assert isinstance(prog[0].op, DimShuffle) - assert isinstance(prog[1].op.scalar_op, ps.Composite) # Composite{add,exp} - # first subtensor - assert isinstance(prog[2].op, Subtensor) - assert len(prog) == 3 - f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something - - def test_basic_7(self): - # basic test that the optimization works with a scalar as input, - # and a scalar as output (no broadcasting of the scalar needed). - # The optimization used to fail and display an ERROR message. - - x = vector("x") - y = scalar("y") - f = function([x, y], exp(x + y)[0], mode=mode_opt) - - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f, ops_to_check=Subtensor) - - prog = f.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) - # Composite{add,exp} - assert isinstance(prog[1].op.scalar_op, ps.Composite) - assert len(prog) == 2 - f([1, 2, 3], 4) # let debugmode test something - - def test_basic_8(self): - # Test that Subtensor(Unbroadcast(x)) gets optimized into - # Unbroadcast(Subtensor(x)). - - # test basic case - x = row("x") - xval = np.random.random((1, 10)).astype(config.floatX) - assert x.broadcastable == (True, False) - newx = Unbroadcast(0)(x) - assert newx.broadcastable == (False, False) - - f1 = function([x], newx[:2, :5], mode=mode_opt) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f1, ops_to_check=[Subtensor, Unbroadcast]) - prog = f1.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, Unbroadcast) - assert (f1(xval) == xval[:2, :5]).all() - - # corner case 1: Unbroadcast changes dims which are dropped through subtensor - y = tensor(dtype="float64", shape=(1, 10, 1, 3), name="x") - yval = np.random.random((1, 10, 1, 3)).astype(config.floatX) - assert y.broadcastable == (True, False, True, False) - newy = Unbroadcast(0, 2)(y) - assert newy.broadcastable == (False, False, False, False) - - f2 = function([y], newy[:, 3, 0, :], mode=mode_opt) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f2, ops_to_check=[Subtensor, Unbroadcast]) - prog = f2.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, Unbroadcast) - assert (f2(yval) == yval[:, 3, 0, :]).all() - - # corner case 2: subtensor idx_list is shorter than resulting broadcast pattern - f3 = function([y], newy[:, 3, 0], mode=mode_opt) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f3, ops_to_check=[Subtensor, Unbroadcast]) - prog = f3.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, Unbroadcast) - assert (f3(yval) == yval[:, 3, 0]).all() - - # corner case 3: subtensor idx_list is shorter than Unbroadcast.axis - z = tensor(dtype="float64", shape=(4, 10, 3, 1), name="x") - zval = np.random.random((4, 10, 3, 1)).astype(config.floatX) - assert z.broadcastable == (False, False, False, True) - newz = Unbroadcast(3)(z) - assert newz.broadcastable == (False, False, False, False) - - f4 = function([z], newz[:, 3, 0], mode=mode_opt) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f4, ops_to_check=[Subtensor, Unbroadcast]) - prog = f4.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, Unbroadcast) - assert (f4(zval) == zval[:, 3, 0]).all() - - class TestLocalSubtensorMerge: def setup_method(self): self.x_shapes = [(2, 2), (5, 3), (4, 1), (1, 2), (0, 2), (2, 0), (1, 0), (0, 0)] @@ -1863,200 +1536,6 @@ def test_local_set_to_inc_subtensor(): assert check_stack_trace(f2, ops_to_check="all") -def test_local_subtensor_of_alloc(): - # DebugMode should detect if something goes wrong. - # test shape combination of odd and event shape. - for s in [(3, 5), (4, 6), (3, 8), (4, 7), (1, 5), (5, 1)]: - x = tensor( - dtype=config.floatX, - shape=(1 if s[0] == 1 else None, 1 if s[1] == 1 else None), - ) - - xval = np.zeros(s, dtype=config.floatX) - yval = np.arange(s[1], dtype=config.floatX) - - for y in [shared(yval), pt.constant([1.0])]: - # The rows of yx are copies of y - yx = pt.alloc(y, x.shape[0], x.shape[1]) - - # Slice of each row - z_mat = yx[:, 3:] - assert z_mat.ndim == 2 - - # Only one column - z_vec = yx[:, 3] - assert z_vec.ndim == 1 - # results are vector - slicess = [] - if s[0] != 1: - slicess.append((2, slice(None))) - if s[1] != 1: - slicess.append((slice(None), 3)) - - # results are matrix - slicess += [ - (slice(None), slice(3, None)), - (slice(3, None),), - (slice(3, None), slice(3, None)), - (slice(1, 3), slice(None, -1)), - (slice(None, None, 2)), - (slice(1, None, 2)), - ] - for slices in slicess: - z = yx.__getitem__(slices) - f = function([x], z) - if config.mode != "FAST_COMPILE": - # Subtensor can be in the input of Alloc - assert not isinstance(f.maker.fgraph.toposort()[-1].op, Subtensor) - val = f(xval) - assert xval.__getitem__(slices).shape == val.shape - - -def test_local_subtensor_shape_constant(): - x = tensor(dtype=np.float64, shape=(1, None)).shape[0] - (res,) = local_subtensor_shape_constant.transform(None, x.owner) - assert isinstance(res, Constant) - assert res.data == 1 - - # Make sure it's part of the canonicalizations - res = rewrite_graph(x) - assert isinstance(res, Constant) - assert res.data == 1 - - x = _shape(tensor(dtype=np.float64, shape=(1, None)))[lscalar()] - assert not local_subtensor_shape_constant.transform(None, x.owner) - - x = _shape(tensor(dtype=np.float64, shape=(1, None)))[0:] - assert not local_subtensor_shape_constant.transform(None, x.owner) - - x = _shape(tensor(dtype=np.float64, shape=(1, None)))[lscalar() :] - assert not local_subtensor_shape_constant.transform(None, x.owner) - - x = _shape(tensor(dtype=np.float64, shape=(1, 1)))[1:] - (res,) = local_subtensor_shape_constant.transform(None, x.owner) - assert isinstance(res, Constant) - assert np.array_equal(res.data, [1]) - - x = _shape(tensor(dtype=np.float64, shape=(None, 1, 1)))[1:] - (res,) = local_subtensor_shape_constant.transform(None, x.owner) - assert isinstance(res, Constant) - assert np.array_equal(res.data, [1, 1]) - - # A test for a non-`TensorType` - class MyType(Type): - def filter(self, *args, **kwargs): - raise NotImplementedError() - - def __eq__(self, other): - return isinstance(other, MyType) and other.thingy == self.thingy - - x = shape(Variable(MyType(), None, None))[0] - - assert not local_subtensor_shape_constant.transform(None, x.owner) - - -@pytest.mark.parametrize( - "x, s, idx, x_val, s_val", - [ - ( - vector(), - (iscalar(),), - (1,), - np.array([1, 2], dtype=config.floatX), - np.array([2], dtype=np.int64), - ), - ( - matrix(), - (iscalar(), iscalar()), - (1,), - np.array([[1, 2], [3, 4]], dtype=config.floatX), - np.array([2, 2], dtype=np.int64), - ), - ( - matrix(), - (iscalar(), iscalar()), - (0,), - np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX), - np.array([2, 3], dtype=np.int64), - ), - ( - matrix(), - (iscalar(), iscalar()), - (1, 1), - np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX), - np.array([2, 3], dtype=np.int64), - ), - ( - tensor3(), - (iscalar(), iscalar(), iscalar()), - (-1,), - np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)), - np.array([2, 3, 5], dtype=np.int64), - ), - ( - tensor3(), - (iscalar(), iscalar(), iscalar()), - (-1, 0), - np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)), - np.array([2, 3, 5], dtype=np.int64), - ), - ], -) -def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val): - y = specify_shape(x, s)[idx] - assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape) - - rewrites = RewriteDatabaseQuery(include=[None]) - no_rewrites_mode = Mode(optimizer=rewrites) - - y_val_fn = function([x, *s], y, on_unused_input="ignore", mode=no_rewrites_mode) - y_val = y_val_fn(*([x_val, *s_val])) - - # This optimization should appear in the canonicalizations - y_opt = rewrite_graph(y, clone=False) - - if y.ndim == 0: - # SpecifyShape should be removed altogether - assert isinstance(y_opt.owner.op, Subtensor) - assert y_opt.owner.inputs[0] is x - else: - assert isinstance(y_opt.owner.op, SpecifyShape) - - y_opt_fn = function([x, *s], y_opt, on_unused_input="ignore") - y_opt_val = y_opt_fn(*([x_val, *s_val])) - - assert np.allclose(y_val, y_opt_val) - - -@pytest.mark.parametrize( - "x, s, idx", - [ - ( - matrix(), - (iscalar(), iscalar()), - (slice(1, None),), - ), - ( - matrix(), - (iscalar(), iscalar()), - (slicetype(),), - ), - ( - matrix(), - (iscalar(), iscalar()), - (1, 0), - ), - ], -) -def test_local_subtensor_SpecifyShape_lift_fail(x, s, idx): - y = specify_shape(x, s)[idx] - - # This optimization should appear in the canonicalizations - y_opt = rewrite_graph(y, clone=False) - - assert not isinstance(y_opt.owner.op, SpecifyShape) - - @pytest.mark.parametrize( "axis, slices_fn, expected_nodes", [ diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py new file mode 100644 index 0000000000..6fc30b8d0d --- /dev/null +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -0,0 +1,561 @@ +import numpy as np +import pytest +import unittest_tools as utt +from tensor.rewriting.test_subtensor import mode_opt + +from pytensor import ( + Mode, + Variable, + config, + function, + shared, +) +from pytensor import ( + scalar as ps, +) +from pytensor import ( + tensor as pt, +) +from pytensor.compile import DeepCopyOp, get_default_mode, get_mode +from pytensor.graph import ( + Constant, + FunctionGraph, + RewriteDatabaseQuery, + Type, + rewrite_graph, +) +from pytensor.graph.rewriting.basic import check_stack_trace +from pytensor.tensor import ( + add, + exp, + inplace, + iscalar, + iscalars, + lscalar, + lscalars, + matrix, + row, + scalar, + shape, + slicetype, + specify_shape, + tensor, + tensor3, + vector, +) +from pytensor.tensor.basic import MakeVector, make_vector +from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.tensor.rewriting.subtensor_lift import ( + local_subtensor_make_vector, + local_subtensor_shape_constant, +) +from pytensor.tensor.shape import SpecifyShape, Unbroadcast, _shape +from pytensor.tensor.subtensor import Subtensor + + +class TestLocalSubtensorLift: + def test_basic(self): + # basic test that the Op works + x = matrix("x") + f = function([x], exp(x)[0], mode=mode_opt) + + # Check stacktrace was copied over correctly after opt was applied + assert check_stack_trace(f, ops_to_check="all") + + prog = f.maker.fgraph.toposort() + assert isinstance(prog[0].op, Subtensor) # first subtensor + assert prog[1].op == exp + assert len(prog) == 2 + f([[0, 1], [2, 3]]) # let debugmode test something + + def test_basic_1(self): + # as test0, but we reuse the output of the elemwise + # So we should not lift the subtensor + x = matrix("x") + f = function([x], [exp(x)[0], exp(x)], mode=mode_opt) + + # Check stacktrace was copied over correctly after opt was applied + assert check_stack_trace(f, ops_to_check=[Subtensor, Elemwise]) + + prog = f.maker.fgraph.toposort() + assert prog[0].op == exp + assert isinstance(prog[1].op, Subtensor) # first subtensor + assert isinstance(prog[2].op, DeepCopyOp) + assert len(prog) == 3 + f([[0, 1], [2, 3]]) # let debugmode test something + + def test_basic_2(self): + # basic test that the optimization work with scalar broadcasted + x = matrix("x") + y = scalar("y") + z = matrix("z") + f = function([x, y, z], exp(x + y + z)[0], mode=mode_opt) + + prog = f.maker.fgraph.toposort() + assert isinstance(prog[0].op, Subtensor) + assert isinstance(prog[1].op, DimShuffle) + assert isinstance(prog[2].op, Subtensor) + assert isinstance(prog[3].op.scalar_op, ps.Composite) # Composite{add,add} + assert len(prog) == 4 + + # Check stacktrace was copied over correctly after opt was applied + assert check_stack_trace(f, ops_to_check=[Subtensor]) + + # let debugmode test something + f([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]]) + + def test_basic_3(self): + # as 1, but take a slice + x = matrix("x") + y = scalar("y") + z = matrix("z") + f = function([x, y, z], exp(x + y + z)[0:2], mode=mode_opt) + + prog = f.maker.fgraph.toposort() + assert isinstance(prog[0].op, Subtensor) + assert isinstance(prog[1].op, DimShuffle) + assert isinstance(prog[2].op, Subtensor) + assert isinstance(prog[3].op.scalar_op, ps.Composite) # Composite{add,add} + assert len(prog) == 4 + + # Check stacktrace was copied over correctly after opt was applied + assert check_stack_trace(f, ops_to_check=[Subtensor]) + + # let debugmode test something + f([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]]) + + def test_basic_4(self): + # basic test that the optimization does work with broadcasting + # for unary elemwise. + y = vector("y") + f = function([y], exp(y.dimshuffle(0, "x"))[0], mode=mode_opt) + + # Check stacktrace was copied over correctly after opt was applied + assert check_stack_trace(f, ops_to_check="all") + + prog = f.maker.fgraph.toposort() + assert isinstance(prog[0].op, DimShuffle) + assert isinstance(prog[1].op, Subtensor) + assert prog[2].op == exp + assert len(prog) == 3 + f([4, 5]) # let debugmode test something + + @utt.assertFailure_fast + def test_basic_5(self): + # basic test that the optimization doesn't work with broadcasting + # ... It *could* be extended to, + # ... but right now it doesn't, so it shouldn't try. + x = matrix("x") + y = vector("y") + f = function([x, y], exp(x + y)[0], mode=mode_opt) + + # Opt doesn't apply, so no need for check_stack_trace + # assert check_stack_trace(f, ops_to_check='all') + + prog = f.maker.fgraph.toposort() + assert isinstance(prog[0].op, DimShuffle) + assert prog[1].op == add + assert isinstance(prog[2].op, Subtensor) # first subtensor + assert prog[3].op == inplace.exp_inplace + assert len(prog) == 4 + f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something + + def test_basic_6(self): + # test that we don't lift when we reuse the output of the + # elemwise for other computation. + x = matrix("x") + y = vector("y") + f = function([x, y], [exp(x + y)[0], exp(x + y) + x], mode=mode_opt) + + # Opt doesn't apply, so no need for check_stack_trace + # assert check_stack_trace(f, ops_to_check=Subtensor) + + prog = f.maker.fgraph.toposort() + assert isinstance(prog[0].op, DimShuffle) + assert isinstance(prog[1].op.scalar_op, ps.Composite) # Composite{add,exp} + # first subtensor + assert isinstance(prog[2].op, Subtensor) + assert len(prog) == 3 + f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something + + def test_basic_7(self): + # basic test that the optimization works with a scalar as input, + # and a scalar as output (no broadcasting of the scalar needed). + # The optimization used to fail and display an ERROR message. + + x = vector("x") + y = scalar("y") + f = function([x, y], exp(x + y)[0], mode=mode_opt) + + # Check stacktrace was copied over correctly after opt was applied + assert check_stack_trace(f, ops_to_check=Subtensor) + + prog = f.maker.fgraph.toposort() + assert isinstance(prog[0].op, Subtensor) + # Composite{add,exp} + assert isinstance(prog[1].op.scalar_op, ps.Composite) + assert len(prog) == 2 + f([1, 2, 3], 4) # let debugmode test something + + def test_basic_8(self): + # Test that Subtensor(Unbroadcast(x)) gets optimized into + # Unbroadcast(Subtensor(x)). + + # test basic case + x = row("x") + xval = np.random.random((1, 10)).astype(config.floatX) + assert x.broadcastable == (True, False) + newx = Unbroadcast(0)(x) + assert newx.broadcastable == (False, False) + + f1 = function([x], newx[:2, :5], mode=mode_opt) + # Check stacktrace was copied over correctly after opt was applied + assert check_stack_trace(f1, ops_to_check=[Subtensor, Unbroadcast]) + prog = f1.maker.fgraph.toposort() + assert isinstance(prog[0].op, Subtensor) + assert isinstance(prog[1].op, Unbroadcast) + assert (f1(xval) == xval[:2, :5]).all() + + # corner case 1: Unbroadcast changes dims which are dropped through subtensor + y = tensor(dtype="float64", shape=(1, 10, 1, 3), name="x") + yval = np.random.random((1, 10, 1, 3)).astype(config.floatX) + assert y.broadcastable == (True, False, True, False) + newy = Unbroadcast(0, 2)(y) + assert newy.broadcastable == (False, False, False, False) + + f2 = function([y], newy[:, 3, 0, :], mode=mode_opt) + # Check stacktrace was copied over correctly after opt was applied + assert check_stack_trace(f2, ops_to_check=[Subtensor, Unbroadcast]) + prog = f2.maker.fgraph.toposort() + assert isinstance(prog[0].op, Subtensor) + assert isinstance(prog[1].op, Unbroadcast) + assert (f2(yval) == yval[:, 3, 0, :]).all() + + # corner case 2: subtensor idx_list is shorter than resulting broadcast pattern + f3 = function([y], newy[:, 3, 0], mode=mode_opt) + # Check stacktrace was copied over correctly after opt was applied + assert check_stack_trace(f3, ops_to_check=[Subtensor, Unbroadcast]) + prog = f3.maker.fgraph.toposort() + assert isinstance(prog[0].op, Subtensor) + assert isinstance(prog[1].op, Unbroadcast) + assert (f3(yval) == yval[:, 3, 0]).all() + + # corner case 3: subtensor idx_list is shorter than Unbroadcast.axis + z = tensor(dtype="float64", shape=(4, 10, 3, 1), name="x") + zval = np.random.random((4, 10, 3, 1)).astype(config.floatX) + assert z.broadcastable == (False, False, False, True) + newz = Unbroadcast(3)(z) + assert newz.broadcastable == (False, False, False, False) + + f4 = function([z], newz[:, 3, 0], mode=mode_opt) + # Check stacktrace was copied over correctly after opt was applied + assert check_stack_trace(f4, ops_to_check=[Subtensor, Unbroadcast]) + prog = f4.maker.fgraph.toposort() + assert isinstance(prog[0].op, Subtensor) + assert isinstance(prog[1].op, Unbroadcast) + assert (f4(zval) == zval[:, 3, 0]).all() + + +def test_local_subtensor_of_alloc(): + # DebugMode should detect if something goes wrong. + # test shape combination of odd and event shape. + for s in [(3, 5), (4, 6), (3, 8), (4, 7), (1, 5), (5, 1)]: + x = tensor( + dtype=config.floatX, + shape=(1 if s[0] == 1 else None, 1 if s[1] == 1 else None), + ) + + xval = np.zeros(s, dtype=config.floatX) + yval = np.arange(s[1], dtype=config.floatX) + + for y in [shared(yval), pt.constant([1.0])]: + # The rows of yx are copies of y + yx = pt.alloc(y, x.shape[0], x.shape[1]) + + # Slice of each row + z_mat = yx[:, 3:] + assert z_mat.ndim == 2 + + # Only one column + z_vec = yx[:, 3] + assert z_vec.ndim == 1 + # results are vector + slicess = [] + if s[0] != 1: + slicess.append((2, slice(None))) + if s[1] != 1: + slicess.append((slice(None), 3)) + + # results are matrix + slicess += [ + (slice(None), slice(3, None)), + (slice(3, None),), + (slice(3, None), slice(3, None)), + (slice(1, 3), slice(None, -1)), + (slice(None, None, 2)), + (slice(1, None, 2)), + ] + for slices in slicess: + z = yx.__getitem__(slices) + f = function([x], z) + if config.mode != "FAST_COMPILE": + # Subtensor can be in the input of Alloc + assert not isinstance(f.maker.fgraph.toposort()[-1].op, Subtensor) + val = f(xval) + assert xval.__getitem__(slices).shape == val.shape + + +@pytest.mark.parametrize( + "x, s, idx, x_val, s_val", + [ + ( + vector(), + (iscalar(),), + (1,), + np.array([1, 2], dtype=config.floatX), + np.array([2], dtype=np.int64), + ), + ( + matrix(), + (iscalar(), iscalar()), + (1,), + np.array([[1, 2], [3, 4]], dtype=config.floatX), + np.array([2, 2], dtype=np.int64), + ), + ( + matrix(), + (iscalar(), iscalar()), + (0,), + np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX), + np.array([2, 3], dtype=np.int64), + ), + ( + matrix(), + (iscalar(), iscalar()), + (1, 1), + np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX), + np.array([2, 3], dtype=np.int64), + ), + ( + tensor3(), + (iscalar(), iscalar(), iscalar()), + (-1,), + np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)), + np.array([2, 3, 5], dtype=np.int64), + ), + ( + tensor3(), + (iscalar(), iscalar(), iscalar()), + (-1, 0), + np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)), + np.array([2, 3, 5], dtype=np.int64), + ), + ], +) +def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val): + y = specify_shape(x, s)[idx] + assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape) + + rewrites = RewriteDatabaseQuery(include=[None]) + no_rewrites_mode = Mode(optimizer=rewrites) + + y_val_fn = function([x, *s], y, on_unused_input="ignore", mode=no_rewrites_mode) + y_val = y_val_fn(*([x_val, *s_val])) + + # This optimization should appear in the canonicalizations + y_opt = rewrite_graph(y, clone=False) + + if y.ndim == 0: + # SpecifyShape should be removed altogether + assert isinstance(y_opt.owner.op, Subtensor) + assert y_opt.owner.inputs[0] is x + else: + assert isinstance(y_opt.owner.op, SpecifyShape) + + y_opt_fn = function([x, *s], y_opt, on_unused_input="ignore") + y_opt_val = y_opt_fn(*([x_val, *s_val])) + + assert np.allclose(y_val, y_opt_val) + + +@pytest.mark.parametrize( + "x, s, idx", + [ + ( + matrix(), + (iscalar(), iscalar()), + (slice(1, None),), + ), + ( + matrix(), + (iscalar(), iscalar()), + (slicetype(),), + ), + ( + matrix(), + (iscalar(), iscalar()), + (1, 0), + ), + ], +) +def test_local_subtensor_SpecifyShape_lift_fail(x, s, idx): + y = specify_shape(x, s)[idx] + + # This optimization should appear in the canonicalizations + y_opt = rewrite_graph(y, clone=False) + + assert not isinstance(y_opt.owner.op, SpecifyShape) + + +class TestLocalSubtensorMakeVector: + mode = get_mode("FAST_RUN").including("local_subtensor_make_vector") + + def test_scalar_idx(self): + x, y, z = lscalars("xyz") + v = make_vector(x, y, z) + f = function([x, y, z], v[0], mode=self.mode) + + prog = f.maker.fgraph.toposort() + assert len(prog) == 1 + assert isinstance(prog[0].op, DeepCopyOp) + assert f(0, 1, 2) == 0 + + def test_idx_symbolic(self): + x, y, z = iscalars("xyz") + v = MakeVector("int32")(x, y, z) + idx = pt.as_tensor([0], dtype=np.int64) + f = function([x, y, z], v[idx], mode=self.mode) + + opt_fgraph = f.maker.fgraph + assert opt_fgraph.outputs[0].dtype == "int32" + assert isinstance(opt_fgraph.outputs[0].owner.op, MakeVector) + assert f(0, 1, 2) == np.array([0], dtype=np.int32) + + def test_slice_idx_start(self): + x, y, z = iscalars("xyz") + v = MakeVector("int32")(x, y, z) + f = function([x, y, z], v[1:], mode=self.mode, on_unused_input="ignore") + + opt_fgraph = f.maker.fgraph + assert opt_fgraph.outputs[0].dtype == "int32" + assert isinstance(opt_fgraph.outputs[0].owner.op, MakeVector) + assert len(opt_fgraph.outputs[0].owner.inputs) == 2 + r = f(0, 1, 2) + assert r[0] == 1 and r[1] == 2 + + def test_slice_idx_stop(self): + x, y, z = lscalars("xyz") + v = make_vector(x, y, z) + f = function([x, y, z], v[:2], mode=self.mode) + + prog = f.maker.fgraph.toposort() + assert len(prog) == 1 + assert isinstance(prog[0].op, MakeVector) + assert len(prog[0].inputs) == 2 + r = f(0, 1, 2) + assert r[0] == 0 and r[1] == 1 + + def test_slice_idx_step(self): + x, y, z = lscalars("xyz") + v = make_vector(x, y, z) + f = function([x, y, z], v[::2], mode=self.mode) + + prog = f.maker.fgraph.toposort() + assert len(prog) == 1 + assert isinstance(prog[0].op, MakeVector) + assert len(prog[0].inputs) == 2 + r = f(0, 1, 2) + assert r[0] == 0 and r[1] == 2 + + def test_AdvancedSubtensor1_idx(self): + x, y, z = lscalars("xyz") + v = make_vector(x, y, z) + f = function([x, y, z], v[[0, 2]], mode=self.mode) + + prog = f.maker.fgraph.toposort() + assert len(prog) == 1 + assert isinstance(prog[0].op, MakeVector) + assert len(prog[0].inputs) == 2 + r = f(0, 1, 2) + assert r[0] == 0 and r[1] == 2 + + def test_MakeVector_idx(self): + x, y, z, q = lscalars("xyzq") + v = make_vector(x, y, z) + q = make_vector(0, 2) + f = function([x, y, z], v[q], mode=self.mode) + + prog = f.maker.fgraph.toposort() + assert len(prog) == 1 + assert isinstance(prog[0].op, MakeVector) + assert len(prog[0].inputs) == 2 + r = f(0, 1, 2) + assert r[0] == 0 and r[1] == 2 + + def test_stack_trace(self): + x, y, z = lscalars("xyz") + v = make_vector(x, y, z) + + mode = get_default_mode().including("local_subtensor_make_vector") + + # list of subtensor cases, where local_subtensor_make_vector + # inserts a new MakeVector node + v_subtensors = [v[:2], v[::2], v[[0, 2]]] + + for v_subtensor in v_subtensors: + f = function([x, y, z], v_subtensor, mode=mode) + assert check_stack_trace(f, ops_to_check="all") + + def test_empty_subtensor(self): + x, y = lscalars("xy") + v = make_vector(x, y) + out = v[()] + + fgraph = FunctionGraph(outputs=[out], clone=False) + node = fgraph.outputs[0].owner + assert isinstance(node.op, Subtensor) + + assert local_subtensor_make_vector.transform(fgraph, node) == [v] + + +def test_local_subtensor_shape_constant(): + x = tensor(dtype=np.float64, shape=(1, None)).shape[0] + (res,) = local_subtensor_shape_constant.transform(None, x.owner) + assert isinstance(res, Constant) + assert res.data == 1 + + # Make sure it's part of the canonicalizations + res = rewrite_graph(x) + assert isinstance(res, Constant) + assert res.data == 1 + + x = _shape(tensor(dtype=np.float64, shape=(1, None)))[lscalar()] + assert not local_subtensor_shape_constant.transform(None, x.owner) + + x = _shape(tensor(dtype=np.float64, shape=(1, None)))[0:] + assert not local_subtensor_shape_constant.transform(None, x.owner) + + x = _shape(tensor(dtype=np.float64, shape=(1, None)))[lscalar() :] + assert not local_subtensor_shape_constant.transform(None, x.owner) + + x = _shape(tensor(dtype=np.float64, shape=(1, 1)))[1:] + (res,) = local_subtensor_shape_constant.transform(None, x.owner) + assert isinstance(res, Constant) + assert np.array_equal(res.data, [1]) + + x = _shape(tensor(dtype=np.float64, shape=(None, 1, 1)))[1:] + (res,) = local_subtensor_shape_constant.transform(None, x.owner) + assert isinstance(res, Constant) + assert np.array_equal(res.data, [1, 1]) + + # A test for a non-`TensorType` + class MyType(Type): + def filter(self, *args, **kwargs): + raise NotImplementedError() + + def __eq__(self, other): + return isinstance(other, MyType) and other.thingy == self.thingy + + x = shape(Variable(MyType(), None, None))[0] + + assert not local_subtensor_shape_constant.transform(None, x.owner) From 7de52829477dfc41dd01a71fdff2f38042c5caa8 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sun, 8 Dec 2024 11:55:58 +0100 Subject: [PATCH 02/10] Group subtensor specify_shape lift tests in class --- tests/tensor/rewriting/test_subtensor_lift.py | 200 +++++++++--------- 1 file changed, 100 insertions(+), 100 deletions(-) diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index 6fc30b8d0d..49fce910bb 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -305,106 +305,106 @@ def test_local_subtensor_of_alloc(): assert xval.__getitem__(slices).shape == val.shape -@pytest.mark.parametrize( - "x, s, idx, x_val, s_val", - [ - ( - vector(), - (iscalar(),), - (1,), - np.array([1, 2], dtype=config.floatX), - np.array([2], dtype=np.int64), - ), - ( - matrix(), - (iscalar(), iscalar()), - (1,), - np.array([[1, 2], [3, 4]], dtype=config.floatX), - np.array([2, 2], dtype=np.int64), - ), - ( - matrix(), - (iscalar(), iscalar()), - (0,), - np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX), - np.array([2, 3], dtype=np.int64), - ), - ( - matrix(), - (iscalar(), iscalar()), - (1, 1), - np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX), - np.array([2, 3], dtype=np.int64), - ), - ( - tensor3(), - (iscalar(), iscalar(), iscalar()), - (-1,), - np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)), - np.array([2, 3, 5], dtype=np.int64), - ), - ( - tensor3(), - (iscalar(), iscalar(), iscalar()), - (-1, 0), - np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)), - np.array([2, 3, 5], dtype=np.int64), - ), - ], -) -def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val): - y = specify_shape(x, s)[idx] - assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape) - - rewrites = RewriteDatabaseQuery(include=[None]) - no_rewrites_mode = Mode(optimizer=rewrites) - - y_val_fn = function([x, *s], y, on_unused_input="ignore", mode=no_rewrites_mode) - y_val = y_val_fn(*([x_val, *s_val])) - - # This optimization should appear in the canonicalizations - y_opt = rewrite_graph(y, clone=False) - - if y.ndim == 0: - # SpecifyShape should be removed altogether - assert isinstance(y_opt.owner.op, Subtensor) - assert y_opt.owner.inputs[0] is x - else: - assert isinstance(y_opt.owner.op, SpecifyShape) - - y_opt_fn = function([x, *s], y_opt, on_unused_input="ignore") - y_opt_val = y_opt_fn(*([x_val, *s_val])) - - assert np.allclose(y_val, y_opt_val) - - -@pytest.mark.parametrize( - "x, s, idx", - [ - ( - matrix(), - (iscalar(), iscalar()), - (slice(1, None),), - ), - ( - matrix(), - (iscalar(), iscalar()), - (slicetype(),), - ), - ( - matrix(), - (iscalar(), iscalar()), - (1, 0), - ), - ], -) -def test_local_subtensor_SpecifyShape_lift_fail(x, s, idx): - y = specify_shape(x, s)[idx] - - # This optimization should appear in the canonicalizations - y_opt = rewrite_graph(y, clone=False) - - assert not isinstance(y_opt.owner.op, SpecifyShape) +class TestLocalSubtensorSpecifyShapeLift: + @pytest.mark.parametrize( + "x, s, idx, x_val, s_val", + [ + ( + vector(), + (iscalar(),), + (1,), + np.array([1, 2], dtype=config.floatX), + np.array([2], dtype=np.int64), + ), + ( + matrix(), + (iscalar(), iscalar()), + (1,), + np.array([[1, 2], [3, 4]], dtype=config.floatX), + np.array([2, 2], dtype=np.int64), + ), + ( + matrix(), + (iscalar(), iscalar()), + (0,), + np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX), + np.array([2, 3], dtype=np.int64), + ), + ( + matrix(), + (iscalar(), iscalar()), + (1, 1), + np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX), + np.array([2, 3], dtype=np.int64), + ), + ( + tensor3(), + (iscalar(), iscalar(), iscalar()), + (-1,), + np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)), + np.array([2, 3, 5], dtype=np.int64), + ), + ( + tensor3(), + (iscalar(), iscalar(), iscalar()), + (-1, 0), + np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)), + np.array([2, 3, 5], dtype=np.int64), + ), + ], + ) + def test_local_subtensor_SpecifyShape_lift(self, x, s, idx, x_val, s_val): + y = specify_shape(x, s)[idx] + assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape) + + rewrites = RewriteDatabaseQuery(include=[None]) + no_rewrites_mode = Mode(optimizer=rewrites) + + y_val_fn = function([x, *s], y, on_unused_input="ignore", mode=no_rewrites_mode) + y_val = y_val_fn(*([x_val, *s_val])) + + # This optimization should appear in the canonicalizations + y_opt = rewrite_graph(y, clone=False) + + if y.ndim == 0: + # SpecifyShape should be removed altogether + assert isinstance(y_opt.owner.op, Subtensor) + assert y_opt.owner.inputs[0] is x + else: + assert isinstance(y_opt.owner.op, SpecifyShape) + + y_opt_fn = function([x, *s], y_opt, on_unused_input="ignore") + y_opt_val = y_opt_fn(*([x_val, *s_val])) + + assert np.allclose(y_val, y_opt_val) + + @pytest.mark.parametrize( + "x, s, idx", + [ + ( + matrix(), + (iscalar(), iscalar()), + (slice(1, None),), + ), + ( + matrix(), + (iscalar(), iscalar()), + (slicetype(),), + ), + ( + matrix(), + (iscalar(), iscalar()), + (1, 0), + ), + ], + ) + def test_local_subtensor_SpecifyShape_lift_fail(self, x, s, idx): + y = specify_shape(x, s)[idx] + + # This optimization should appear in the canonicalizations + y_opt = rewrite_graph(y, clone=False) + + assert not isinstance(y_opt.owner.op, SpecifyShape) class TestLocalSubtensorMakeVector: From 19550e8f294ea6cdda5197d65a2b621e8d9002f6 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 20 Jan 2025 11:49:12 +0100 Subject: [PATCH 03/10] Cache sub-type of DimShuffle --- pytensor/tensor/elemwise.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index cb60427ba0..c37597906a 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -166,15 +166,20 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]): self.transposition = self.shuffle + drop # List of dimensions of the output that are broadcastable and were not # in the original input - self.augment = sorted(i for i, x in enumerate(new_order) if x == "x") + self.augment = augment = sorted(i for i, x in enumerate(new_order) if x == "x") self.drop = drop - self.is_left_expand_dims = self.augment and ( + dims_are_shuffled = sorted(self.shuffle) != self.shuffle + + self.is_transpose = dims_are_shuffled and not augment and not drop + self.is_squeeze = drop and not dims_are_shuffled and not augment + self.is_expand_dims = augment and not dims_are_shuffled and not drop + self.is_left_expand_dims = self.is_expand_dims and ( input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim)) ) - self.is_right_expand_dims = self.augment and new_order[:input_ndim] == list( - range(input_ndim) - ) + self.is_right_expand_dims = self.is_expand_dims and new_order[ + :input_ndim + ] == list(range(input_ndim)) if self.inplace: self.view_map = {0: [0]} @@ -215,16 +220,15 @@ def make_node(self, inp): return Apply(self, [input], [output]) def __str__(self): - shuffle = sorted(self.shuffle) != self.shuffle - if self.augment and not (shuffle or self.drop): + if self.is_expand_dims: if len(self.augment) == 1: return f"ExpandDims{{axis={self.augment[0]}}}" return f"ExpandDims{{axes={self.augment}}}" - if self.drop and not (self.augment or shuffle): + if self.is_squeeze: if len(self.drop) == 1: - return f"DropDims{{axis={self.drop[0]}}}" - return f"DropDims{{axes={self.drop}}}" - if shuffle and not (self.augment or self.drop): + return f"Squeeze{{axis={self.drop[0]}}}" + return f"Squeeze{{axes={self.drop}}}" + if self.is_transpose: return f"Transpose{{axes={self.shuffle}}}" return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}" From 2ad449b027cef6e35e1163bcb1740933857be2da Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 20 Jan 2025 11:49:47 +0100 Subject: [PATCH 04/10] Lift Subtensor over expand_dims --- pytensor/tensor/rewriting/subtensor.py | 26 +++--- pytensor/tensor/rewriting/subtensor_lift.py | 80 ++++++++++++++++++- tests/tensor/rewriting/test_subtensor_lift.py | 70 +++++++++++++++- 3 files changed, 157 insertions(+), 19 deletions(-) diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index ca27761319..94f4126609 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -74,7 +74,7 @@ indices_from_subtensor, ) from pytensor.tensor.type import TensorType, integer_dtypes -from pytensor.tensor.type_other import NoneTypeT, SliceConstant, SliceType +from pytensor.tensor.type_other import NoneTypeT, SliceType from pytensor.tensor.variable import TensorConstant, TensorVariable @@ -154,19 +154,17 @@ def transform_take(a, indices, axis): def is_full_slice(x): """Determine if `x` is a ``slice(None)`` or a symbolic equivalent.""" - if ( - (isinstance(x, slice) and x == slice(None)) - or (isinstance(x, SliceConstant) and x.value == slice(None)) - or ( - not isinstance(x, SliceConstant) - and isinstance(getattr(x, "type", None), SliceType) - and x.owner is not None - and all( - isinstance(getattr(i, "type", None), NoneTypeT) for i in x.owner.inputs - ) - ) - ): - return True + if isinstance(x, slice): + return x == slice(None) + + if isinstance(x, Variable) and isinstance(x.type, SliceType): + if isinstance(x, Constant): + return x.data == slice(None) + else: + # Symbolic MakeSlice + # Ignores start = 0, step = 1 cases + return all(isinstance(i.type, NoneTypeT) for i in x.owner.inputs) + return False diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index 839d43f53f..606121f125 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -11,10 +11,11 @@ MakeVector, alloc, as_tensor, + expand_dims, get_underlying_scalar_constant_value, register_infer_shape, ) -from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import Dot, ceil_intdiv, dot from pytensor.tensor.rewriting.basic import ( @@ -22,7 +23,7 @@ register_specialize, register_stabilize, ) -from pytensor.tensor.rewriting.subtensor import register_useless +from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless from pytensor.tensor.shape import ( Shape, SpecifyShape, @@ -37,6 +38,7 @@ get_canonical_form_slice, get_constant_idx, get_idx_list, + indices_from_subtensor, ) from pytensor.tensor.type import TensorType from pytensor.tensor.type_other import SliceType @@ -204,6 +206,80 @@ def local_subtensor_lift(fgraph, node): return [rbcast_subt_x] +@register_canonicalize("shape_unsafe") +@register_specialize("shape_unsafe") +@node_rewriter([Subtensor]) +def local_subtensor_of_expand_dims(fgraph, node): + """Lift a Subtensor through a DimShuffle that only expands dims. + + expand_dims(x, axis=0)[0] -> x + expand_dims(x, axis=0)[:, 0] -> expand_dims(x[0], axis=0) + expand_dims(x, axis=2)[0] -> expand_dims(x[0], axis=1) + + This goes beyond `local_subtensor_remove_broadcastable_index` which + simply removes useless subtensors on broadcastable dimensions. + """ + ds, *idx = node.inputs + + if not (ds.owner and isinstance(ds.owner.op, DimShuffle)): + return None + + ds_op = ds.owner.op + + if not ds_op.is_expand_dims: + return None + + expanded_axes = ds_op.augment + [x] = ds.owner.inputs + + idx_tuple = indices_from_subtensor(idx, node.op.idx_list) + + # Keep indexes for the original dimensions, and drop indexes for the expanded dimensions when safe + new_idxs = [] + for i, idx_item in enumerate(idx_tuple): + if i in expanded_axes: + if isinstance(idx_item, slice): + # Slice could be keeping or dropping this dimension + if is_full_slice(idx_item): + # A None slice, always keeps the dimension. + # We skip the index, and later introduce the needed expand_dim + continue + else: + # Other slices could keep or drop the dimension. + # Get out instead o trying to figure out which case it is + return None + else: + # Integer indexing can only drop the dimension (if it's a valid graph) + # We can just drop the index and avoid expanding the dimension + # This is why this rewrite is tagged with "shape_unsafe" + continue + else: + # Keep indexes for non-expanded dimensions + new_idxs.append(idx_item) + + [old_out] = node.outputs + out = x[tuple(new_idxs)] + copy_stack_trace(old_out, out) + + if out.type.broadcastable != old_out.type.broadcastable: + # Re-introduce needed new dimensions (corresponding to full slices on the original expanded dimensions) + # If out.type.broadcastable == (False) and old_out.type.broadcastable == (True, False, True) + # then axis = (0, 2) + old_bcast = list(old_out.type.broadcastable) + expanded_bcast = list(out.type.broadcastable) + axis = [] + i = 0 + while i < len(old_bcast): + if i == len(expanded_bcast) or expanded_bcast[i] != old_bcast[i]: + expanded_bcast.insert(i, True) + axis.append(i) + i += 1 + out = expand_dims(out, axis=axis) + copy_stack_trace(old_out, out) + + return [out] + + @register_infer_shape @register_useless @register_canonicalize diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index 49fce910bb..5ecef0ac67 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -24,7 +24,9 @@ Type, rewrite_graph, ) +from pytensor.graph.basic import equal_computations from pytensor.graph.rewriting.basic import check_stack_trace +from pytensor.printing import debugprint from pytensor.tensor import ( add, exp, @@ -43,7 +45,7 @@ tensor3, vector, ) -from pytensor.tensor.basic import MakeVector, make_vector +from pytensor.tensor.basic import MakeVector, expand_dims, make_vector from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.rewriting.subtensor_lift import ( local_subtensor_make_vector, @@ -53,6 +55,9 @@ from pytensor.tensor.subtensor import Subtensor +NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None) + + class TestLocalSubtensorLift: def test_basic(self): # basic test that the Op works @@ -134,8 +139,8 @@ def test_basic_4(self): assert check_stack_trace(f, ops_to_check="all") prog = f.maker.fgraph.toposort() - assert isinstance(prog[0].op, DimShuffle) - assert isinstance(prog[1].op, Subtensor) + assert isinstance(prog[0].op, Subtensor) + assert isinstance(prog[1].op, DimShuffle) assert prog[2].op == exp assert len(prog) == 3 f([4, 5]) # let debugmode test something @@ -256,6 +261,65 @@ def test_basic_8(self): assert (f4(zval) == zval[:, 3, 0]).all() +@pytest.mark.parametrize( + "original_fn, expected_fn", + [ + # Integer indexing + (lambda x: expand_dims(x, axis=0)[0], lambda x: x), + ( + lambda x: expand_dims(x, axis=1)[0], + lambda x: expand_dims(x[0], axis=0), + ), + ( + lambda x: expand_dims(x, axis=(1, 3))[0], + lambda x: expand_dims(x[0], axis=(0, 2)), + ), + # Slice indexing + ( + lambda x: expand_dims(x, axis=1)[1:], + lambda x: expand_dims(x[1:], axis=1), + ), + ( + lambda x: expand_dims(x, axis=(1, 3))[1:], + lambda x: expand_dims(x[1:], axis=(1, 3)), + ), + # Not supported, slice indexing on expanded dimension + ( + lambda x: expand_dims(x, axis=0)[1:], + lambda x: expand_dims(x, axis=0)[1:], + ), + # Mixed indexing + ( + lambda x: expand_dims(x, axis=1)[0, :, 1:], + lambda x: expand_dims(x[0, 1:], axis=0), + ), + ( + lambda x: expand_dims(x, axis=1)[1:, :, 0], + lambda x: expand_dims(x[1:, 0], axis=1), + ), + ( + lambda x: expand_dims(x, axis=(1, 2))[1:, :, 0], + lambda x: expand_dims(x[1:], axis=1), + ), + ], +) +def test_local_subtensor_of_expand_dims(original_fn, expected_fn): + rng = np.random.default_rng(232) + x = tensor("x", shape=(5, 3)) + x_test = rng.normal(size=x.type.shape) + + out = original_fn(x) + expected_opt_out = expected_fn(x) + opt_out = rewrite_graph(out, exclude=["local_uint_constant_indices"]) + assert equal_computations([opt_out], [expected_opt_out]), debugprint( + [opt_out, expected_opt_out], print_type=True + ) + np.testing.assert_allclose( + opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE), + out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE), + ) + + def test_local_subtensor_of_alloc(): # DebugMode should detect if something goes wrong. # test shape combination of odd and event shape. From 5c97e9f3873c023f2736f2418abfa6fe638d393b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 20 Jan 2025 12:15:37 +0100 Subject: [PATCH 05/10] Lift Subtensor over transpose --- pytensor/tensor/rewriting/subtensor_lift.py | 58 ++++++++++++++++++- tests/tensor/rewriting/test_subtensor_lift.py | 31 +++++++++- 2 files changed, 87 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index 606121f125..1de943b223 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -1,4 +1,4 @@ -from collections.abc import Iterable +from collections.abc import Iterable, Sequence import numpy as np @@ -17,12 +17,14 @@ ) from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError +from pytensor.tensor.extra_ops import squeeze from pytensor.tensor.math import Dot, ceil_intdiv, dot from pytensor.tensor.rewriting.basic import ( register_canonicalize, register_specialize, register_stabilize, ) +from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless from pytensor.tensor.shape import ( Shape, @@ -44,6 +46,12 @@ from pytensor.tensor.type_other import SliceType +def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]: + # Inputs can be slice or integer indexes + # Slices keep the dimensions, integers collapse them + return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice)) + + @register_canonicalize @register_stabilize @register_specialize @@ -280,6 +288,54 @@ def local_subtensor_of_expand_dims(fgraph, node): return [out] +@register_canonicalize +@register_specialize +@node_rewriter([Subtensor]) +def local_subtensor_of_transpose(fgraph, node): + """Lift a Subtensor through a DimShuffle that only transposes. + + transpose(x, (1, 0, 2))[i:, j:, k:] -> transpose(x[j:, i:, k:], (1, 0, 2)) + """ + ds, *idx = node.inputs + + if not (ds.owner and isinstance(ds.owner.op, DimShuffle)): + return None + + ds_op = ds.owner.op + if not ds_op.is_transpose: + return None + + transposition = ds_op.transposition + [x] = ds.owner.inputs + + idx_tuple = indices_from_subtensor(idx, node.op.idx_list) + + # Apply the transposition to the indexes + n_implicit_idxs = x.type.ndim - len(idx_tuple) + idx_tuple = idx_tuple + (slice(None),) * n_implicit_idxs + new_idxs = [idx_tuple[i] for i in transposition] + new_x = x[tuple(new_idxs)] + + # Reintroduce any dims dropped by indexing so the original transpose still works + dims_dropped_by_new_idx = _dims_dropped_by_basic_index(new_idxs) + if dims_dropped_by_new_idx: + new_x = expand_dims(new_x, axis=dims_dropped_by_new_idx) + + # Apply the transpose + new_out = ds_op(new_x) + + # Squeeze dims again now that the transpose is done + if dims_dropped_by_new_idx: + dims_dropped_by_original_idx = _dims_dropped_by_basic_index(idx_tuple) + new_out = squeeze(new_out, axis=dims_dropped_by_original_idx) + + # Cleanup consecutive expand_dims / transpose / squeeze (if any) + if dims_dropped_by_new_idx: + [new_out] = local_dimshuffle_lift.transform(fgraph, new_out.owner) + + return [new_out] + + @register_infer_shape @register_useless @register_canonicalize diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index 5ecef0ac67..036ff43285 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -310,7 +310,7 @@ def test_local_subtensor_of_expand_dims(original_fn, expected_fn): out = original_fn(x) expected_opt_out = expected_fn(x) - opt_out = rewrite_graph(out, exclude=["local_uint_constant_indices"]) + opt_out = rewrite_graph(out) assert equal_computations([opt_out], [expected_opt_out]), debugprint( [opt_out, expected_opt_out], print_type=True ) @@ -320,6 +320,35 @@ def test_local_subtensor_of_expand_dims(original_fn, expected_fn): ) +@pytest.mark.parametrize( + "original_fn, expected_fn", + [ + (lambda x: x.transpose(2, 1, 0)[0], lambda x: x[:, :, 0].transpose(1, 0)), + (lambda x: x.transpose(2, 1, 0)[:, :, 1:], lambda x: x[1:].transpose(2, 1, 0)), + ( + lambda x: x.transpose(2, 1, 0)[0, :1, 1:], + lambda x: x[1:, :1, 0].transpose(1, 0), + ), + (lambda x: x.transpose(2, 1, 0)[0, :1, 1], lambda x: x[1, :1, 0]), + ], +) +def test_local_subtensor_of_transpose(original_fn, expected_fn): + rng = np.random.default_rng(232) + x = tensor("x", shape=(7, 5, 3)) + x_test = rng.normal(size=x.type.shape) + + out = original_fn(x) + expected_opt_out = expected_fn(x) + opt_out = rewrite_graph(out) + assert equal_computations([opt_out], [expected_opt_out]), debugprint( + [expected_opt_out, opt_out], print_type=True + ) + np.testing.assert_allclose( + opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE), + out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE), + ) + + def test_local_subtensor_of_alloc(): # DebugMode should detect if something goes wrong. # test shape combination of odd and event shape. From e1ee3a2c079fb5d926271208ba329e8bf3a498bb Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 20 Jan 2025 12:51:34 +0100 Subject: [PATCH 06/10] Generalize lift of Subtensor over Elemwise Split off Subtensor of Unbroadcast into its own rewrite --- pytensor/tensor/rewriting/subtensor_lift.py | 204 +++++------ tests/tensor/rewriting/test_subtensor_lift.py | 316 ++++++++---------- 2 files changed, 256 insertions(+), 264 deletions(-) diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index 1de943b223..465a73f6dd 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -110,108 +110,79 @@ def local_subtensor_of_dot(fgraph, node): return [r] -# fast_compile to allow opt subtensor(cast{float32}(make_vector)) -@register_canonicalize("fast_compile") +@register_canonicalize("shape_unsafe") +@register_specialize("shape_unsafe") @node_rewriter([Subtensor]) -def local_subtensor_lift(fgraph, node): +def local_subtensor_of_elemwise(fgraph, node): + """Lift a Subtensor through an Elemwise and its implicit broadcasting behavior. + + exp(x)[:, 0] -> exp(x[:, 0]) + add(x, y)[0] -> add(x[0], y[0]) + add(x[None], y)[2] -> add(x, y[2]) """ - unary(x)[idx] -> unary(x[idx])#any broadcast pattern. + elem, *idx = node.inputs - Handles the following unary ops: - elemwise(x,...)[idx] -> elemwise(x[idx],...) - when x,... are broadcasted scalar or not broadcasted at all - Unbroadcast(x)[idx] => Unbroadcast(x[idx]) + if not (elem.owner and isinstance(elem.owner.op, Elemwise)): + return None - """ - if isinstance(node.op, Subtensor): - u = node.inputs[0] - if u.owner is None or len(fgraph.clients[u]) > 1: - return False - - if isinstance(u.owner.op, Elemwise) and len(u.owner.inputs) == 1: - idx = node.inputs[1:] - x_idx = node.op(u.owner.inputs[0], *idx) - # Copy over previous output stacktrace - copy_stack_trace(node.outputs, x_idx) - ret = u.owner.op(x_idx) - # Copy over previous output stacktrace - # and stacktrace from previous unary operation - copy_stack_trace([node.outputs[0], node.inputs[0]], ret) - return [ret] + if len(fgraph.clients[elem]) > 1: + # Elemwise output is used beyond the Subtensor. + # Get out to avoid repeated computations + return None - if isinstance(u.owner.op, Elemwise): - new_inputs = [] - if all(sum(i.type.broadcastable) == 0 for i in u.owner.inputs): - # There is no broadcastable in the inputs - idx = node.inputs[1:] - new_inputs = [node.op(i, *idx) for i in u.owner.inputs] - # Copy over previous output stacktrace - copy_stack_trace(node.outputs[0], new_inputs) - - ret = u.owner.op(*new_inputs) - # Copy over previous output stacktrace - # and stacktrace from previous unary operation - copy_stack_trace([node.outputs[0], node.inputs[0]], ret) - return [ret] - elif all(sum(i.type.broadcastable) in [i.ndim, 0] for i in u.owner.inputs): - # There is no broadcastable in the inputs or it is scalar - idx = node.inputs[1:] - new_inputs = [] - for i in u.owner.inputs: - if sum(i.type.broadcastable) == 0: - new_inputs.append(node.op(i, *idx)) - else: - # If the subtensor remove some dims, we must - # lower the number of dimensions of this scalar. - if node.outputs[0].ndim == i.ndim: - new_inputs.append(i) - else: - new_inputs.append( - i.dimshuffle(["x"] * node.outputs[0].ndim) - ) - - # Copy over previous output stacktrace - copy_stack_trace(node.outputs[0], new_inputs) - - ret = u.owner.op(*new_inputs) - # Copy over previous output stacktrace - # and stacktrace from previous unary operation - copy_stack_trace([node.outputs[0], node.inputs[0]], ret) - return [ret] - - if isinstance(u.owner.op, Unbroadcast): - # Subtensor might reduce dim., adapt broadcast pattern accordingly - old_axes = u.owner.op.axes - new_axes = [] - - # loop through indices being subtensor-ed - # i indexes broadcastable pattern before subtensor - # j indexes broadcastable pattern after subtensor - j = 0 - for i, x in enumerate(node.op.idx_list): - # if it is not a slice, it will reduce the dimension, should - # not appear in the broascastable dimensions - if isinstance(x, slice): - if i in old_axes: - new_axes.append(j) - j += 1 - # now keep the broadcastable pattern of all - # items not appearing in subtensor list - for i in range(len(node.op.idx_list), len(u.broadcastable)): - if i in old_axes: - new_axes.append(j) - j += 1 + idx_tuple = indices_from_subtensor(idx, node.op.idx_list) - subt_x = node.op(u.owner.inputs[0], *node.inputs[1:]) - # Copy over previous output stacktrace - copy_stack_trace(node.outputs[0], subt_x) + elem_inputs = elem.owner.inputs + elem_bcast = elem.type.broadcastable + if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs): + # No need to worry about implicit broadcasting. + indexed_inputs = [inp[idx_tuple] for inp in elem_inputs] + + else: + # The original indices may not make sense on some of the broadcasted dimensions + new_idxs = [list(idx_tuple) for _ in elem_inputs] + for dim, (dim_idx, dim_bcast_out, *dim_bcast_inputs) in enumerate( + zip( + idx_tuple, + elem_bcast, + *(inp.type.broadcastable for inp in elem_inputs), + # Indices can be shorter than input ndims + strict=False, + ) + ): + if is_full_slice(dim_idx): + # Full slice can be safely applied to all inputs + continue - rbcast_subt_x = unbroadcast(subt_x, *new_axes) - # Copy over previous output stacktrace - # and stacktrace from previous unary operation - copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x) + if all(dim_bcast_inp == elem_bcast for dim_bcast_inp in dim_bcast_inputs): + # This dim is not broadcasted for any of the inputs, original index can be applied to all inputs + continue + + # Some dims are broadcasted, so we need to adapt their indices + # Slice indexing keeps the dimension, so we use a full slice for broadcasted inputs + # Integer indexing drops the dimension, so we index by zero for the broadcsated inputs + safe_bcast_dim_idx = slice(None) if isinstance(dim_idx, slice) else 0 + for inp_idx, dim_bcast_inp in zip(new_idxs, dim_bcast_inputs, strict=True): + if dim_bcast_inp: + inp_idx[dim] = safe_bcast_dim_idx - return [rbcast_subt_x] + indexed_inputs = [ + inp[tuple(new_idx)] + for inp, new_idx in zip(elem_inputs, new_idxs, strict=True) + ] + + [old_out] = node.outputs + + # Copy stack trace to new inputs + [copy_stack_trace(old_out, new_inp) for new_inp in indexed_inputs] + + # Define elemwise operation on indexed inputs + new_out = elem.owner.op(*indexed_inputs) + + # Copy stack trace to new output + copy_stack_trace([old_out, *node.inputs], new_out) + + return [new_out] @register_canonicalize("shape_unsafe") @@ -336,6 +307,51 @@ def local_subtensor_of_transpose(fgraph, node): return [new_out] +@register_canonicalize("fast_compile") +@node_rewriter([Subtensor]) +def local_subtensor_of_unbroadcast(fgraph, node): + """ + Unbroadcast(x)[idx] => Unbroadcast(x[idx]) + """ + u = node.inputs[0] + if u.owner is None or len(fgraph.clients[u]) > 1: + return False + + if isinstance(u.owner.op, Unbroadcast): + # Subtensor might reduce dim., adapt broadcast pattern accordingly + old_axes = u.owner.op.axes + new_axes = [] + + # loop through indices being subtensor-ed + # i indexes broadcastable pattern before subtensor + # j indexes broadcastable pattern after subtensor + j = 0 + for i, x in enumerate(node.op.idx_list): + # if it is not a slice, it will reduce the dimension, should + # not appear in the broascastable dimensions + if isinstance(x, slice): + if i in old_axes: + new_axes.append(j) + j += 1 + # now keep the broadcastable pattern of all + # items not appearing in subtensor list + for i in range(len(node.op.idx_list), len(u.broadcastable)): + if i in old_axes: + new_axes.append(j) + j += 1 + + subt_x = node.op(u.owner.inputs[0], *node.inputs[1:]) + # Copy over previous output stacktrace + copy_stack_trace(node.outputs[0], subt_x) + + rbcast_subt_x = unbroadcast(subt_x, *new_axes) + # Copy over previous output stacktrace + # and stacktrace from previous unary operation + copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x) + + return [rbcast_subt_x] + + @register_infer_shape @register_useless @register_canonicalize diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index 036ff43285..a4622e363d 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -1,6 +1,5 @@ import numpy as np import pytest -import unittest_tools as utt from tensor.rewriting.test_subtensor import mode_opt from pytensor import ( @@ -30,14 +29,12 @@ from pytensor.tensor import ( add, exp, - inplace, iscalar, iscalars, lscalar, lscalars, matrix, row, - scalar, shape, slicetype, specify_shape, @@ -49,6 +46,7 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.rewriting.subtensor_lift import ( local_subtensor_make_vector, + local_subtensor_of_elemwise, local_subtensor_shape_constant, ) from pytensor.tensor.shape import SpecifyShape, Unbroadcast, _shape @@ -58,22 +56,8 @@ NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None) -class TestLocalSubtensorLift: - def test_basic(self): - # basic test that the Op works - x = matrix("x") - f = function([x], exp(x)[0], mode=mode_opt) - - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f, ops_to_check="all") - - prog = f.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) # first subtensor - assert prog[1].op == exp - assert len(prog) == 2 - f([[0, 1], [2, 3]]) # let debugmode test something - - def test_basic_1(self): +class TestLocalSubtensorOfElemwise: + def test_unary_multiple_clients(self): # as test0, but we reuse the output of the elemwise # So we should not lift the subtensor x = matrix("x") @@ -87,85 +71,16 @@ def test_basic_1(self): assert isinstance(prog[1].op, Subtensor) # first subtensor assert isinstance(prog[2].op, DeepCopyOp) assert len(prog) == 3 - f([[0, 1], [2, 3]]) # let debugmode test something - - def test_basic_2(self): - # basic test that the optimization work with scalar broadcasted - x = matrix("x") - y = scalar("y") - z = matrix("z") - f = function([x, y, z], exp(x + y + z)[0], mode=mode_opt) - - prog = f.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, DimShuffle) - assert isinstance(prog[2].op, Subtensor) - assert isinstance(prog[3].op.scalar_op, ps.Composite) # Composite{add,add} - assert len(prog) == 4 - - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f, ops_to_check=[Subtensor]) - - # let debugmode test something - f([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]]) - - def test_basic_3(self): - # as 1, but take a slice - x = matrix("x") - y = scalar("y") - z = matrix("z") - f = function([x, y, z], exp(x + y + z)[0:2], mode=mode_opt) - - prog = f.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, DimShuffle) - assert isinstance(prog[2].op, Subtensor) - assert isinstance(prog[3].op.scalar_op, ps.Composite) # Composite{add,add} - assert len(prog) == 4 - - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f, ops_to_check=[Subtensor]) - - # let debugmode test something - f([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]]) - - def test_basic_4(self): - # basic test that the optimization does work with broadcasting - # for unary elemwise. - y = vector("y") - f = function([y], exp(y.dimshuffle(0, "x"))[0], mode=mode_opt) - - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f, ops_to_check="all") - - prog = f.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, DimShuffle) - assert prog[2].op == exp - assert len(prog) == 3 - f([4, 5]) # let debugmode test something - - @utt.assertFailure_fast - def test_basic_5(self): - # basic test that the optimization doesn't work with broadcasting - # ... It *could* be extended to, - # ... but right now it doesn't, so it shouldn't try. - x = matrix("x") - y = vector("y") - f = function([x, y], exp(x + y)[0], mode=mode_opt) - - # Opt doesn't apply, so no need for check_stack_trace - # assert check_stack_trace(f, ops_to_check='all') - prog = f.maker.fgraph.toposort() - assert isinstance(prog[0].op, DimShuffle) - assert prog[1].op == add - assert isinstance(prog[2].op, Subtensor) # first subtensor - assert prog[3].op == inplace.exp_inplace - assert len(prog) == 4 - f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something + x_test = [[0, 1], [2, 3]] + res1, res2 = f(x_test) + np.testing.assert_allclose( + res1, + np.exp(x_test)[0], + ) + np.testing.assert_allclose(res2, np.exp(x_test)) - def test_basic_6(self): + def test_multinary_multiple_clients(self): # test that we don't lift when we reuse the output of the # elemwise for other computation. x = matrix("x") @@ -181,84 +96,145 @@ def test_basic_6(self): # first subtensor assert isinstance(prog[2].op, Subtensor) assert len(prog) == 3 - f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something - def test_basic_7(self): - # basic test that the optimization works with a scalar as input, - # and a scalar as output (no broadcasting of the scalar needed). - # The optimization used to fail and display an ERROR message. + x_test = np.array([[0, 1], [2, 3]]) + y_test = np.array([4, 5]) + res1, res2 = f(x_test, y_test) + np.testing.assert_allclose( + res1, + np.exp(x_test + y_test)[0], + ) + np.testing.assert_allclose( + res2, + np.exp(x_test + y_test) + x_test, + ) - x = vector("x") - y = scalar("y") - f = function([x, y], exp(x + y)[0], mode=mode_opt) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f, ops_to_check=Subtensor) +@pytest.mark.parametrize( + "original_fn, expected_fn", + [ + # Unary integer indexing + (lambda x, y: exp(x)[0], lambda x, y: exp(x[0])), + # Unary integer with expand_dims + (lambda x, y: exp(x[:, None])[0], lambda x, y: exp(x[0][None])), + # Integer indexing on non-broadcastable dimension + (lambda x, y: add(x, y)[0], lambda x, y: add(x[0], y[0])), + # Slice indexing on non-broadcastable dimension + (lambda x, y: add(x, y)[1:], lambda x, y: add(x[1:], y[1:])), + # Integer indexing on broacastable dimension + (lambda x, y: add(x[None], y[None])[0], lambda x, y: add(x, y)), + (lambda x, y: add(x[None], y[None])[0, 1], lambda x, y: add(x[1], y[1])), + ( + lambda x, y: add(x[None, :], y[:, None])[2], + lambda x, y: add(x, y[2][None]), + ), + ( + lambda x, y: add(x[:, None], y[None, :])[:, 2], + lambda x, y: add(x, y[2][None]), + ), + # Slice indexing on broadcastable dimension + ( + lambda x, y: add(x[None], y[None])[1:], + lambda x, y: add(x[None][1:], y[None][1:]), + ), + ( + lambda x, y: add(x[None, :], y[:, None])[1:], + lambda x, y: add(x[None, :], y[1:][:, None]), + ), + ], +) +def test_local_subtensor_of_elemwise(self, original_fn, expected_fn): + rng = np.random.default_rng(257) + x = pt.matrix("x", shape=(5, 3)) + y = pt.matrix("y", shape=(5, 3)) + x_test = rng.normal(size=x.type.shape) + y_test = rng.normal(size=y.type.shape) - prog = f.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) - # Composite{add,exp} - assert isinstance(prog[1].op.scalar_op, ps.Composite) - assert len(prog) == 2 - f([1, 2, 3], 4) # let debugmode test something - - def test_basic_8(self): - # Test that Subtensor(Unbroadcast(x)) gets optimized into - # Unbroadcast(Subtensor(x)). - - # test basic case - x = row("x") - xval = np.random.random((1, 10)).astype(config.floatX) - assert x.broadcastable == (True, False) - newx = Unbroadcast(0)(x) - assert newx.broadcastable == (False, False) - - f1 = function([x], newx[:2, :5], mode=mode_opt) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f1, ops_to_check=[Subtensor, Unbroadcast]) - prog = f1.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, Unbroadcast) - assert (f1(xval) == xval[:2, :5]).all() - - # corner case 1: Unbroadcast changes dims which are dropped through subtensor - y = tensor(dtype="float64", shape=(1, 10, 1, 3), name="x") - yval = np.random.random((1, 10, 1, 3)).astype(config.floatX) - assert y.broadcastable == (True, False, True, False) - newy = Unbroadcast(0, 2)(y) - assert newy.broadcastable == (False, False, False, False) - - f2 = function([y], newy[:, 3, 0, :], mode=mode_opt) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f2, ops_to_check=[Subtensor, Unbroadcast]) - prog = f2.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, Unbroadcast) - assert (f2(yval) == yval[:, 3, 0, :]).all() - - # corner case 2: subtensor idx_list is shorter than resulting broadcast pattern - f3 = function([y], newy[:, 3, 0], mode=mode_opt) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f3, ops_to_check=[Subtensor, Unbroadcast]) - prog = f3.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, Unbroadcast) - assert (f3(yval) == yval[:, 3, 0]).all() - - # corner case 3: subtensor idx_list is shorter than Unbroadcast.axis - z = tensor(dtype="float64", shape=(4, 10, 3, 1), name="x") - zval = np.random.random((4, 10, 3, 1)).astype(config.floatX) - assert z.broadcastable == (False, False, False, True) - newz = Unbroadcast(3)(z) - assert newz.broadcastable == (False, False, False, False) - - f4 = function([z], newz[:, 3, 0], mode=mode_opt) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f4, ops_to_check=[Subtensor, Unbroadcast]) - prog = f4.maker.fgraph.toposort() - assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, Unbroadcast) - assert (f4(zval) == zval[:, 3, 0]).all() + out = original_fn(x, y) + expected_opt_out = expected_fn(x, y) + opt_out = rewrite_graph(out) + assert equal_computations([opt_out], [expected_opt_out]), debugprint( + [expected_opt_out, opt_out], print_type=True + ) + eval_kwargs = dict(mode=NO_OPTIMIZATION_MODE, on_unused_input="ignore") + np.testing.assert_allclose( + opt_out.eval({x: x_test, y: y_test}, **eval_kwargs), + out.eval({x: x_test, y: y_test}, **eval_kwargs), + ) + + +def test_local_subtensor_of_elemwise_multiple_clients(): + x = pt.matrix("x", shape=(5, 3)) + y = pt.matrix("y", shape=(5, 3)) + out1 = add(x, y) + out2 = out1[0] + + # Rewrite should fail when another node uses out1 directly (in this case it's an extra output) + fgraph = FunctionGraph([x, y], [out1, out2], clone=False) + assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is None + + # Otherwise it should work + fgraph.remove_output(0) + assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None + + +def test_local_subtensor_of_unbroadcast(): + # Test that Subtensor(Unbroadcast(x)) gets optimized into + # Unbroadcast(Subtensor(x)). + + # test basic case + x = row("x") + xval = np.random.random((1, 10)).astype(config.floatX) + assert x.broadcastable == (True, False) + newx = Unbroadcast(0)(x) + assert newx.broadcastable == (False, False) + + f1 = function([x], newx[:2, :5], mode=mode_opt) + # Check stacktrace was copied over correctly after opt was applied + assert check_stack_trace(f1, ops_to_check=[Subtensor, Unbroadcast]) + prog = f1.maker.fgraph.toposort() + assert isinstance(prog[0].op, Subtensor) + assert isinstance(prog[1].op, Unbroadcast) + assert (f1(xval) == xval[:2, :5]).all() + + # corner case 1: Unbroadcast changes dims which are dropped through subtensor + y = tensor(dtype="float64", shape=(1, 10, 1, 3), name="x") + yval = np.random.random((1, 10, 1, 3)).astype(config.floatX) + assert y.broadcastable == (True, False, True, False) + newy = Unbroadcast(0, 2)(y) + assert newy.broadcastable == (False, False, False, False) + + f2 = function([y], newy[:, 3, 0, :], mode=mode_opt) + # Check stacktrace was copied over correctly after opt was applied + assert check_stack_trace(f2, ops_to_check=[Subtensor, Unbroadcast]) + prog = f2.maker.fgraph.toposort() + assert isinstance(prog[0].op, Subtensor) + assert isinstance(prog[1].op, Unbroadcast) + assert (f2(yval) == yval[:, 3, 0, :]).all() + + # corner case 2: subtensor idx_list is shorter than resulting broadcast pattern + f3 = function([y], newy[:, 3, 0], mode=mode_opt) + # Check stacktrace was copied over correctly after opt was applied + assert check_stack_trace(f3, ops_to_check=[Subtensor, Unbroadcast]) + prog = f3.maker.fgraph.toposort() + assert isinstance(prog[0].op, Subtensor) + assert isinstance(prog[1].op, Unbroadcast) + assert (f3(yval) == yval[:, 3, 0]).all() + + # corner case 3: subtensor idx_list is shorter than Unbroadcast.axis + z = tensor(dtype="float64", shape=(4, 10, 3, 1), name="x") + zval = np.random.random((4, 10, 3, 1)).astype(config.floatX) + assert z.broadcastable == (False, False, False, True) + newz = Unbroadcast(3)(z) + assert newz.broadcastable == (False, False, False, False) + + f4 = function([z], newz[:, 3, 0], mode=mode_opt) + # Check stacktrace was copied over correctly after opt was applied + assert check_stack_trace(f4, ops_to_check=[Subtensor, Unbroadcast]) + prog = f4.maker.fgraph.toposort() + assert isinstance(prog[0].op, Subtensor) + assert isinstance(prog[1].op, Unbroadcast) + assert (f4(zval) == zval[:, 3, 0]).all() @pytest.mark.parametrize( From ffcfa7d4acc979a247555e1bf7543408bbb8a687 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 20 Jan 2025 14:27:21 +0100 Subject: [PATCH 07/10] Lift Subtensor over CAReduce --- pytensor/tensor/rewriting/subtensor_lift.py | 60 ++++++++++++++++++- tests/tensor/rewriting/test_subtensor_lift.py | 35 +++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index 465a73f6dd..7a4a9f3216 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -1,6 +1,7 @@ from collections.abc import Iterable, Sequence import numpy as np +from numpy.core.numeric import normalize_axis_tuple # type: ignore from pytensor import Variable from pytensor.graph import Constant, node_rewriter @@ -15,7 +16,7 @@ get_underlying_scalar_constant_value, register_infer_shape, ) -from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.extra_ops import squeeze from pytensor.tensor.math import Dot, ceil_intdiv, dot @@ -185,6 +186,63 @@ def local_subtensor_of_elemwise(fgraph, node): return [new_out] +@register_canonicalize +@register_specialize +@node_rewriter([Subtensor]) +def local_subtensor_of_reduce(fgraph, node): + """Lift a Subtensor through a CAReduce Op. + + For now rewrite is restricted to single axis of reduction, for simplicity. + + sum(x, axis=1)[0] -> sum(x[0], axis=0) + sum(x, axis=1)[1:] -> sum(x[1:], axis=1) + sum(x, axis=0)[0] -> sum(x[:, 0], axis=0) + sum(x, axis=0)[1:] -> sum(x[:, 1:], axis=0) + + """ + red, *idx = node.inputs + + if not (red.owner and isinstance(red.owner.op, CAReduce)): + return None + + if len(fgraph.clients[red]) > 1: + # Don't apply rewrite if another node requires the full reduction + return None + + [x] = red.owner.inputs + axis = red.owner.op.axis + + if axis is None: + axis = tuple(range(x.type.ndim)) + + # TODO: Allow reduction across multiple axis + if len(axis) != 1: + return None + + [axis] = normalize_axis_tuple(axis, x.ndim) + idx_tuple = indices_from_subtensor(idx, node.op.idx_list) + + # Index input of reduction. + new_idxs = list(idx_tuple) + if axis < len(idx_tuple): + # When there are indexes beyond the axis of reduction, we need to shift them with None slices. + new_idxs.insert(axis, slice(None)) + x_sub = x[tuple(new_idxs)] + + [old_out] = node.outputs + copy_stack_trace(old_out, x_sub) + + # Adjust axis of reduction when indexing drops dimensions (integer indexing as apposed to slice indexing) + axis -= len( + [idx_item for idx_item in idx_tuple[:axis] if not isinstance(idx_item, slice)] + ) + + # Apply reduction to indexed input + out = type(red.owner.op)(axis=axis)(x_sub) + copy_stack_trace(old_out, out) + return [out] + + @register_canonicalize("shape_unsafe") @register_specialize("shape_unsafe") @node_rewriter([Subtensor]) diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index a4622e363d..0b087d1b12 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -44,6 +44,7 @@ ) from pytensor.tensor.basic import MakeVector, expand_dims, make_vector from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.rewriting.subtensor_lift import ( local_subtensor_make_vector, local_subtensor_of_elemwise, @@ -178,6 +179,40 @@ def test_local_subtensor_of_elemwise_multiple_clients(): assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None +@pytest.mark.parametrize( + "original_fn, expected_fn", + [ + # Indexing before axis of reduction + (lambda x: pt_sum(x, axis=2)[0], lambda x: pt_sum(x[0], axis=1)), + (lambda x: pt_sum(x, axis=2)[0, 1], lambda x: pt_sum(x[0, 1], axis=None)), + (lambda x: pt_sum(x, axis=2)[1:], lambda x: pt_sum(x[1:], axis=2)), + # Indexing "at" axis of reduction + (lambda x: pt_sum(x, axis=0)[2], lambda x: pt_sum(x[:, 2], axis=0)), + (lambda x: pt_sum(x, axis=0)[:-2], lambda x: pt_sum(x[:, :-2], axis=0)), + # Index after axis of reduction + (lambda x: pt_sum(x, axis=0)[:, 1:], lambda x: pt_sum(x[:, :, 1:], axis=0)), + # Index before and after axis reduction + (lambda x: pt_sum(x, axis=1)[-2, 1:], lambda x: pt_sum(x[-2, :, 1:], axis=0)), + (lambda x: pt_sum(x, axis=1)[1:, -2], lambda x: pt_sum(x[1:, :, -2], axis=1)), + ], +) +def test_local_subtensor_of_reduce(original_fn, expected_fn): + rng = np.random.default_rng(245) + x = pt.tensor("x", shape=(5, 3, 2)) + x_test = rng.normal(size=x.type.shape) + + out = original_fn(x) + expected_opt_out = expected_fn(x) + opt_out = rewrite_graph(out) + assert equal_computations([opt_out], [expected_opt_out]), debugprint( + [expected_opt_out, opt_out], print_type=True + ) + np.testing.assert_allclose( + opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE), + out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE), + ) + + def test_local_subtensor_of_unbroadcast(): # Test that Subtensor(Unbroadcast(x)) gets optimized into # Unbroadcast(Subtensor(x)). From 844ae152bef2cd380a3c1401f1601521c5b32d68 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 20 Jan 2025 16:52:17 +0100 Subject: [PATCH 08/10] Lift Subtensor over Softmax --- pytensor/tensor/rewriting/subtensor_lift.py | 98 ++++++++++++++++++- tests/tensor/rewriting/test_subtensor_lift.py | 39 ++++++++ 2 files changed, 136 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index 7a4a9f3216..b8a7374f7c 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -1,7 +1,10 @@ from collections.abc import Iterable, Sequence import numpy as np -from numpy.core.numeric import normalize_axis_tuple # type: ignore +from numpy.core.numeric import ( # type: ignore + normalize_axis_index, + normalize_axis_tuple, +) from pytensor import Variable from pytensor.graph import Constant, node_rewriter @@ -34,6 +37,7 @@ specify_shape, unbroadcast, ) +from pytensor.tensor.special import Softmax, softmax from pytensor.tensor.subtensor import ( AdvancedSubtensor1, Subtensor, @@ -53,6 +57,20 @@ def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...] return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice)) +def _ndim_dropped_left_of_axis_by_basic_index( + idxs: Sequence[slice | int], axis: int +) -> int: + return len(_dims_dropped_by_basic_index(idxs[:axis])) + + +def _axis_is_indexed_by_basic_index( + idxs: Sequence[slice | int], axis: int | Sequence[int] +) -> bool: + if isinstance(axis, int): + axis = (axis,) + return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis) + + @register_canonicalize @register_stabilize @register_specialize @@ -243,6 +261,84 @@ def local_subtensor_of_reduce(fgraph, node): return [out] +@register_canonicalize +@register_specialize +@node_rewriter([Subtensor]) +def local_subtensor_of_softmax(fgraph, node): + """Lift a Subtensor through a Softmax. + + softmax(x, axis=1)[0] -> softmax(x[0], axis=0) + softmax(x, axis=1)[:, :, 0] -> softmax(x[:, :, 0], axis=1) + + If part of the indexing acts on the axis of reduction, we split it + softmax(x, axis=1)[:, 0, 1:] -> softmax(x[:, :, 1:], axis=1)[0] + + """ + sm, *idx = node.inputs + + if not (sm.owner and isinstance(sm.owner.op, Softmax)): + return None + + if len(fgraph.clients[sm]) > 1: + return None + + [x] = sm.owner.inputs + axis = sm.owner.op.axis + + if axis is None: + if x.type.ndim == 1: + axis = 0 + else: + # All dimensions are mixed, we can't lift the subtensor + return None + else: + # Softmax currently only allows None or a single integer axis + # Unlike CAReduce it does not normalize negative indices + axis = normalize_axis_index(axis, sm.ndim) + + [old_out] = node.outputs + idx_tuple = indices_from_subtensor(idx, node.op.idx_list) + + if _axis_is_indexed_by_basic_index(idx_tuple, axis): + # If there are more dimensions being indexed, we can split them + # And lift the non-axis indexes while keeping the axis index + real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)] + if len(real_indices) > 1 and sm.type.ndim > 1: + # Split the subtensor + idx_to_keep = idx_tuple[axis] + idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :]) + + # Lift the non-axis indexes by calling the rewrite itself + opt_sm = sm[idxs_to_lift] + [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner) + copy_stack_trace([old_out, sm], opt_sm) + + # Then reintroduce the axis index + ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index( + idx_tuple, axis + ) + new_axis = axis - ndim_reduced_left + idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep) + new_out = opt_sm[idxs_to_keep] + copy_stack_trace(old_out, new_out) + return [new_out] + + else: + return None + + # Index input to softmax + x_sub = x[idx_tuple] + + # Adjust axis of reduction when indexing drops dimensions (integer indexing as apposed to slice indexing) + axis -= len( + [idx_item for idx_item in idx_tuple[:axis] if not isinstance(idx_item, slice)] + ) + + out = softmax(x_sub, axis=axis) + copy_stack_trace(old_out, out) + return [out] + + @register_canonicalize("shape_unsafe") @register_specialize("shape_unsafe") @node_rewriter([Subtensor]) diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index 0b087d1b12..6c15589401 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -51,6 +51,7 @@ local_subtensor_shape_constant, ) from pytensor.tensor.shape import SpecifyShape, Unbroadcast, _shape +from pytensor.tensor.special import softmax from pytensor.tensor.subtensor import Subtensor @@ -213,6 +214,44 @@ def test_local_subtensor_of_reduce(original_fn, expected_fn): ) +@pytest.mark.parametrize( + "original_fn, expected_fn", + [ + # Lift single index that does not ovelap with axis of softmax + (lambda x: softmax(x, axis=1)[0], lambda x: softmax(x[0], axis=0)), + (lambda x: softmax(x, axis=1)[1:], lambda x: softmax(x[1:], axis=1)), + (lambda x: softmax(x, axis=0)[:, 0], lambda x: softmax(x[:, 0], axis=0)), + (lambda x: softmax(x, axis=0)[:, 1:], lambda x: softmax(x[:, 1:], axis=0)), + # Do nothing to single index over axis of softmax + (lambda x: softmax(x, axis=0)[0], lambda x: softmax(x, axis=0)[0]), + (lambda x: softmax(x, axis=1)[:, 1:], lambda x: softmax(x, axis=1)[:, 1:]), + # Split indexing on axis of softmax + (lambda x: softmax(x, axis=0)[1:, 0], lambda x: softmax(x[:, 0], axis=0)[1:]), + (lambda x: softmax(x, axis=1)[1:, 0], lambda x: softmax(x[1:], axis=1)[:, 0]), + ( + lambda x: softmax(x, axis=0)[0, :5:2], + lambda x: softmax(x[:, :5:2], axis=0)[0], + ), + (lambda x: softmax(x, axis=1)[0, :5:2], lambda x: softmax(x[0], axis=0)[:5:2]), + ], +) +def test_local_subtensor_of_softmax(original_fn, expected_fn): + rng = np.random.default_rng(230) + x = pt.matrix("x", shape=(5, 3)) + x_test = rng.normal(size=x.type.shape) + + out = original_fn(x) + expected_opt_out = expected_fn(x) + opt_out = rewrite_graph(out) + assert equal_computations([opt_out], [expected_opt_out]), debugprint( + [expected_opt_out, opt_out], print_type=True + ) + np.testing.assert_allclose( + opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE), + out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE), + ) + + def test_local_subtensor_of_unbroadcast(): # Test that Subtensor(Unbroadcast(x)) gets optimized into # Unbroadcast(Subtensor(x)). From 508d5b91713a1f5ed10f88444191f6d021acf059 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 20 Jan 2025 17:22:53 +0100 Subject: [PATCH 09/10] Lift Subtensor over Join --- pytensor/tensor/rewriting/subtensor_lift.py | 120 ++++++++++++++---- tests/tensor/rewriting/test_subtensor_lift.py | 56 +++++++- 2 files changed, 150 insertions(+), 26 deletions(-) diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index b8a7374f7c..961967d0d0 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -1,4 +1,5 @@ from collections.abc import Iterable, Sequence +from typing import cast import numpy as np from numpy.core.numeric import ( # type: ignore @@ -7,16 +8,18 @@ ) from pytensor import Variable -from pytensor.graph import Constant, node_rewriter -from pytensor.graph.rewriting.basic import copy_stack_trace +from pytensor.graph import Constant, FunctionGraph, node_rewriter +from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace from pytensor.scalar import basic as ps from pytensor.tensor.basic import ( Alloc, + Join, MakeVector, alloc, as_tensor, expand_dims, get_underlying_scalar_constant_value, + join, register_infer_shape, ) from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise @@ -49,6 +52,7 @@ ) from pytensor.tensor.type import TensorType from pytensor.tensor.type_other import SliceType +from pytensor.tensor.variable import TensorVariable def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]: @@ -71,6 +75,41 @@ def _axis_is_indexed_by_basic_index( return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis) +def _lift_subtensor_non_axis( + local_subtensor_lift_rewrite: NodeRewriter, + fgraph: FunctionGraph, + variable: TensorVariable, + idx_tuple: tuple[int | slice], + axis: int, + old_subtensor_variable: TensorVariable, +) -> None | list[TensorVariable]: + # Apply generic subtensor lift rewrite along "non-axis" dimensions + real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)] + if len(real_indices) > 1 and variable.type.ndim > 1: + # Split the subtensor + idx_to_keep = idx_tuple[axis] + idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :]) + + # Lift the non-axis indexes by calling the rewrite itself + indexed_variable = variable[idxs_to_lift] + [indexed_variable] = cast( + list[TensorVariable], + local_subtensor_lift_rewrite.transform(fgraph, indexed_variable.owner), + ) + copy_stack_trace([old_subtensor_variable, indexed_variable], indexed_variable) + + # Then reintroduce the axis index + ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis) + new_axis = axis - ndim_reduced_left + idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep) + new_out = indexed_variable[idxs_to_keep] + copy_stack_trace(old_subtensor_variable, new_out) + return [new_out] + + else: + return None + + @register_canonicalize @register_stabilize @register_specialize @@ -302,29 +341,14 @@ def local_subtensor_of_softmax(fgraph, node): if _axis_is_indexed_by_basic_index(idx_tuple, axis): # If there are more dimensions being indexed, we can split them # And lift the non-axis indexes while keeping the axis index - real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)] - if len(real_indices) > 1 and sm.type.ndim > 1: - # Split the subtensor - idx_to_keep = idx_tuple[axis] - idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :]) - - # Lift the non-axis indexes by calling the rewrite itself - opt_sm = sm[idxs_to_lift] - [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner) - copy_stack_trace([old_out, sm], opt_sm) - - # Then reintroduce the axis index - ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index( - idx_tuple, axis - ) - new_axis = axis - ndim_reduced_left - idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep) - new_out = opt_sm[idxs_to_keep] - copy_stack_trace(old_out, new_out) - return [new_out] - - else: - return None + return _lift_subtensor_non_axis( + local_subtensor_lift_rewrite=local_subtensor_of_softmax, + fgraph=fgraph, + variable=sm, + idx_tuple=idx_tuple, + axis=axis, + old_subtensor_variable=old_out, + ) # Index input to softmax x_sub = x[idx_tuple] @@ -695,6 +719,52 @@ def local_subtensor_make_vector(fgraph, node): pass +@register_canonicalize +@register_specialize +@node_rewriter([Subtensor]) +def local_subtensor_of_join(fgraph, node): + """Lift a Subtensor through a Join. + + join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0]) + join(axis=1, x, y)[:, 0, -1] -> join(axis=1, x[:, :, -1], y[:, :, -1])[:, 0] + + """ + join_var, *idx = node.inputs + + if not (join_var.owner and isinstance(join_var.owner.op, Join)): + return None + + if len(fgraph.clients[join_var]) > 1: + # Join involves a full_copy, so we don't want to do it twice + return None + + join_axis, *join_components = join_var.owner.inputs + + # Rewrite only works when the join axis is a constant along a non-indexed dimension + if not isinstance(join_axis, Constant): + return None + + [old_out] = node.outputs + axis = normalize_axis_index(join_axis.data, join_components[0].type.ndim) + idx_tuple = indices_from_subtensor(idx, node.op.idx_list) + if _axis_is_indexed_by_basic_index(idx_tuple, axis): + return _lift_subtensor_non_axis( + local_subtensor_lift_rewrite=local_subtensor_of_join, + fgraph=fgraph, + variable=join_var, + idx_tuple=idx_tuple, + axis=axis, + old_subtensor_variable=old_out, + ) + + # Lift index to the Join components + indexed_components = [component[idx_tuple] for component in join_components] + new_axis = axis - _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis) + out = join(new_axis, *indexed_components) + + return [out] + + @register_specialize @register_canonicalize @node_rewriter([Subtensor]) diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index 6c15589401..78e529178e 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -42,7 +42,7 @@ tensor3, vector, ) -from pytensor.tensor.basic import MakeVector, expand_dims, make_vector +from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.rewriting.subtensor_lift import ( @@ -252,6 +252,9 @@ def test_local_subtensor_of_softmax(original_fn, expected_fn): ) +shared_axis = shared(1, "axis") + + def test_local_subtensor_of_unbroadcast(): # Test that Subtensor(Unbroadcast(x)) gets optimized into # Unbroadcast(Subtensor(x)). @@ -661,6 +664,57 @@ def test_empty_subtensor(self): assert local_subtensor_make_vector.transform(fgraph, node) == [v] +@pytest.mark.parametrize( + "original_fn, expected_fn", + [ + ( + lambda x, y: concatenate([x, y], axis=1)[1], + lambda x, y: concatenate([x[1], y[1]], axis=0), + ), + ( + lambda x, y: concatenate([x, y], axis=-1)[1:], + lambda x, y: concatenate([x[1:], y[1:]], axis=1), + ), + # Indexing on both axis of concatenation and somewhere else: + ( + lambda x, y: concatenate([x, y], axis=1)[0, 1:], + lambda x, y: concatenate([x[0], y[0]], axis=0)[1:], + ), + # Not supported, indexing on axis of concatenation + ( + lambda x, y: concatenate([x, y], axis=0)[0], + lambda x, y: concatenate([x, y], axis=0)[0], + ), + ( + lambda x, y: concatenate([x, y], axis=1)[:, 1:], + lambda x, y: concatenate([x, y], axis=1)[:, 1:], + ), + # Not supported, axis of concatenation is dynamically determined + ( + lambda x, y: concatenate([x, y], axis=shared_axis)[1], + lambda x, y: concatenate([x, y], axis=shared_axis)[1], + ), + ], +) +def test_local_subtensor_of_join(original_fn, expected_fn): + rng = np.random.default_rng(257) + x = pt.matrix("x", shape=(5, 3)) + y = pt.matrix("y", shape=(5, 3)) + x_test = rng.normal(size=x.type.shape) + y_test = rng.normal(size=y.type.shape) + + out = original_fn(x, y) + expected_opt_out = expected_fn(x, y) + opt_out = rewrite_graph(out) + assert equal_computations([opt_out], [expected_opt_out]), debugprint( + [expected_opt_out, opt_out], print_type=True + ) + np.testing.assert_allclose( + opt_out.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE), + out.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE), + ) + + def test_local_subtensor_shape_constant(): x = tensor(dtype=np.float64, shape=(1, None)).shape[0] (res,) = local_subtensor_shape_constant.transform(None, x.owner) From d1b5784d8efa08a9e531e4d1878687d64216e5e6 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 20 Jan 2025 18:08:00 +0100 Subject: [PATCH 10/10] Lift Subtensor over AdvancedSubtensor --- pytensor/tensor/rewriting/subtensor_lift.py | 81 ++++++++++++++++++- tests/tensor/rewriting/test_subtensor_lift.py | 46 ++++++++++- 2 files changed, 125 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index 961967d0d0..bc34b19d4d 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -8,6 +8,7 @@ ) from pytensor import Variable +from pytensor.compile import optdb from pytensor.graph import Constant, FunctionGraph, node_rewriter from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace from pytensor.scalar import basic as ps @@ -42,8 +43,10 @@ ) from pytensor.tensor.special import Softmax, softmax from pytensor.tensor.subtensor import ( + AdvancedSubtensor, AdvancedSubtensor1, Subtensor, + _non_contiguous_adv_indexing, as_index_literal, get_canonical_form_slice, get_constant_idx, @@ -51,7 +54,7 @@ indices_from_subtensor, ) from pytensor.tensor.type import TensorType -from pytensor.tensor.type_other import SliceType +from pytensor.tensor.type_other import NoneTypeT, SliceType from pytensor.tensor.variable import TensorVariable @@ -818,3 +821,79 @@ def local_subtensor_shape_constant(fgraph, node): return [as_tensor([1] * len(shape_parts), dtype=np.int64, ndim=1)] elif shape_parts: return [as_tensor(1, dtype=np.int64)] + + +@node_rewriter([Subtensor]) +def local_subtensor_of_adv_subtensor(fgraph, node): + """Lift a simple Subtensor through an AdvancedSubtensor, when basic index dimensions are to the left of any advanced ones. + + x[:, :, vec_idx][i, j] -> x[i, j][vec_idx] + x[:, vec_idx][i, j, k] -> x[i][vec_idx][j, k] + + Restricted to a single advanced indexing dimension. + + An alternative approach could have fused the basic and advanced indices, + so it is not clear this rewrite should be canonical or a specialization. + Users must include it manually if it fits their use case. + """ + adv_subtensor, *idxs = node.inputs + + if not ( + adv_subtensor.owner and isinstance(adv_subtensor.owner.op, AdvancedSubtensor) + ): + return None + + if len(fgraph.clients[adv_subtensor]) > 1: + # AdvancedSubtensor involves a full_copy, so we don't want to do it twice + return None + + x, *adv_idxs = adv_subtensor.owner.inputs + + # Advanced indexing is a minefield, avoid all cases except for consecutive integer indices + if any( + ( + isinstance(adv_idx.type, NoneTypeT) + or (isinstance(adv_idx.type, TensorType) and adv_idx.type.dtype == "bool") + or (isinstance(adv_idx.type, SliceType) and not is_full_slice(adv_idx)) + ) + for adv_idx in adv_idxs + ) or _non_contiguous_adv_indexing(adv_idxs): + return None + + for first_adv_idx_dim, adv_idx in enumerate(adv_idxs): + # We already made sure there were only None slices besides integer indexes + if isinstance(adv_idx.type, TensorType): + break + else: # no-break + # Not sure if this should ever happen, but better safe than sorry + return None + + basic_idxs = indices_from_subtensor(idxs, node.op.idx_list) + basic_idxs_lifted = basic_idxs[:first_adv_idx_dim] + basic_idxs_kept = ((slice(None),) * len(basic_idxs_lifted)) + basic_idxs[ + first_adv_idx_dim: + ] + + if all(basic_idx == slice(None) for basic_idx in basic_idxs_lifted): + # All basic indices happen to the right of the advanced indices + return None + + [basic_subtensor] = node.outputs + dropped_dims = _dims_dropped_by_basic_index(basic_idxs_lifted) + + x_indexed = x[basic_idxs_lifted] + copy_stack_trace([basic_subtensor, adv_subtensor], x_indexed) + + x_after_index_lift = expand_dims(x_indexed, dropped_dims) + x_after_adv_idx = adv_subtensor.owner.op(x_after_index_lift, *adv_idxs) + copy_stack_trace([basic_subtensor, adv_subtensor], x_after_adv_idx) + + new_out = squeeze(x_after_adv_idx[basic_idxs_kept], dropped_dims) + return [new_out] + + +# Rewrite will only be included if tagged by name +r = local_subtensor_of_adv_subtensor +optdb["canonicalize"].register(r.__name__, r, use_db_name_as_tag=False) +optdb["specialize"].register(r.__name__, r, use_db_name_as_tag=False) +del r diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index 78e529178e..e02fdc1083 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -52,7 +52,7 @@ ) from pytensor.tensor.shape import SpecifyShape, Unbroadcast, _shape from pytensor.tensor.special import softmax -from pytensor.tensor.subtensor import Subtensor +from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None) @@ -756,3 +756,47 @@ def __eq__(self, other): x = shape(Variable(MyType(), None, None))[0] assert not local_subtensor_shape_constant.transform(None, x.owner) + + +@pytest.mark.parametrize( + "original_fn, supported", + [ + (lambda x: x[:, [0, 1]][0], True), + (lambda x: x[:, [0, 1], [0, 0]][1:], True), + (lambda x: x[:, [[0, 1], [0, 0]]][1:], True), + # Not supported, basic indexing on advanced indexing dim + (lambda x: x[[0, 1]][0], False), + # Not implemented, basic indexing on the right of advanced indexing + (lambda x: x[[0, 1]][:, 0], False), + # Not implemented, complex flavors of advanced indexing + (lambda x: x[:, None, [0, 1]][0], False), + (lambda x: x[:, 5:, [0, 1]][0], False), + (lambda x: x[:, :, np.array([True, False, False])][0], False), + (lambda x: x[[0, 1], :, [0, 1]][:, 0], False), + ], +) +def test_local_subtensor_of_adv_subtensor(original_fn, supported): + rng = np.random.default_rng(257) + x = pt.tensor3("x", shape=(7, 5, 3)) + x_test = rng.normal(size=x.type.shape) + + out = original_fn(x) + opt_out = rewrite_graph( + out, include=("canonicalize", "local_subtensor_of_adv_subtensor") + ) + # The graphs generated are too complicated to assert + # We simply check that the happens before the advanced subtensor + toposort = FunctionGraph(outputs=[opt_out], clone=False).toposort() + [idx_subtensor] = [ + i for i, node in enumerate(toposort) if isinstance(node.op, Subtensor) + ] + [idx_adv_subtensor] = [ + i for i, node in enumerate(toposort) if isinstance(node.op, AdvancedSubtensor) + ] + swapped = idx_subtensor < idx_adv_subtensor + correct = swapped if supported else not swapped + assert correct, debugprint(opt_out, print_type=True) + np.testing.assert_allclose( + opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE), + out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE), + )