Lift Subtensor over AdvancedSubtensor
ricardoV94 committed Jan 21, 2025
1 parent 508d5b9 commit d1b5784
Showing 2 changed files with 125 additions and 2 deletions.
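Not part of the commit: a minimal NumPy sketch illustrating the indexing equivalence the rewrite relies on — basic integer indices that sit to the left of an advanced index commute with it. The names `vec_idx`, `i`, and `j` are illustrative.

```python
import numpy as np

x = np.arange(7 * 5 * 3).reshape(7, 5, 3)
vec_idx = np.array([0, 2])
i, j = 1, 3

# Advanced index applied first, then the basic integer indices...
lhs = x[:, :, vec_idx][i, j]
# ...equals lifting the basic indices in front of the advanced one
rhs = x[i, j][vec_idx]
np.testing.assert_array_equal(lhs, rhs)
```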
81 changes: 80 additions & 1 deletion pytensor/tensor/rewriting/subtensor_lift.py
@@ -8,6 +8,7 @@
)

from pytensor import Variable
from pytensor.compile import optdb
from pytensor.graph import Constant, FunctionGraph, node_rewriter
from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
from pytensor.scalar import basic as ps
@@ -42,16 +43,18 @@
)
from pytensor.tensor.special import Softmax, softmax
from pytensor.tensor.subtensor import (
    AdvancedSubtensor,
    AdvancedSubtensor1,
    Subtensor,
    _non_contiguous_adv_indexing,
    as_index_literal,
    get_canonical_form_slice,
    get_constant_idx,
    get_idx_list,
    indices_from_subtensor,
)
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import SliceType
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.variable import TensorVariable


@@ -818,3 +821,79 @@ def local_subtensor_shape_constant(fgraph, node):
        return [as_tensor([1] * len(shape_parts), dtype=np.int64, ndim=1)]
    elif shape_parts:
        return [as_tensor(1, dtype=np.int64)]


@node_rewriter([Subtensor])
def local_subtensor_of_adv_subtensor(fgraph, node):
    """Lift a simple Subtensor through an AdvancedSubtensor, when basic index dimensions are to the left of any advanced ones.

    x[:, :, vec_idx][i, j] -> x[i, j][vec_idx]
    x[:, vec_idx][i, j, k] -> x[i][vec_idx][j, k]

    Restricted to a single advanced indexing dimension.

    An alternative approach could have fused the basic and advanced indices,
    so it is not clear this rewrite should be canonical or a specialization.
    Users must include it manually if it fits their use case.
    """
    adv_subtensor, *idxs = node.inputs

    if not (
        adv_subtensor.owner and isinstance(adv_subtensor.owner.op, AdvancedSubtensor)
    ):
        return None

    if len(fgraph.clients[adv_subtensor]) > 1:
        # AdvancedSubtensor involves a full_copy, so we don't want to do it twice
        return None

    x, *adv_idxs = adv_subtensor.owner.inputs

    # Advanced indexing is a minefield, avoid all cases except for consecutive integer indices
    if any(
        (
            isinstance(adv_idx.type, NoneTypeT)
            or (isinstance(adv_idx.type, TensorType) and adv_idx.type.dtype == "bool")
            or (isinstance(adv_idx.type, SliceType) and not is_full_slice(adv_idx))
        )
        for adv_idx in adv_idxs
    ) or _non_contiguous_adv_indexing(adv_idxs):
        return None

    for first_adv_idx_dim, adv_idx in enumerate(adv_idxs):
        # We already made sure there were only None slices besides integer indexes
        if isinstance(adv_idx.type, TensorType):
            break
    else:  # no-break
        # Not sure if this should ever happen, but better safe than sorry
        return None

    basic_idxs = indices_from_subtensor(idxs, node.op.idx_list)
    basic_idxs_lifted = basic_idxs[:first_adv_idx_dim]
    basic_idxs_kept = ((slice(None),) * len(basic_idxs_lifted)) + basic_idxs[
        first_adv_idx_dim:
    ]

    if all(basic_idx == slice(None) for basic_idx in basic_idxs_lifted):
        # All basic indices happen to the right of the advanced indices
        return None

    [basic_subtensor] = node.outputs
    dropped_dims = _dims_dropped_by_basic_index(basic_idxs_lifted)

    x_indexed = x[basic_idxs_lifted]
    copy_stack_trace([basic_subtensor, adv_subtensor], x_indexed)

    x_after_index_lift = expand_dims(x_indexed, dropped_dims)
    x_after_adv_idx = adv_subtensor.owner.op(x_after_index_lift, *adv_idxs)
    copy_stack_trace([basic_subtensor, adv_subtensor], x_after_adv_idx)

    new_out = squeeze(x_after_adv_idx[basic_idxs_kept], dropped_dims)
    return [new_out]


# Rewrite will only be included if tagged by name
r = local_subtensor_of_adv_subtensor
optdb["canonicalize"].register(r.__name__, r, use_db_name_as_tag=False)
optdb["specialize"].register(r.__name__, r, use_db_name_as_tag=False)
del r
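Not part of the commit: a minimal opt-in sketch mirroring the test below. Because the rewrite is registered with `use_db_name_as_tag=False`, it only runs when requested by name alongside a standard pass. It assumes `rewrite_graph` is importable from `pytensor.graph.rewriting.utils`, as in the test module.

```python
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.tensor3("x", shape=(7, 5, 3))
out = x[:, [0, 1]][0]  # Subtensor applied on top of an AdvancedSubtensor
# Request the rewrite by name next to the canonicalize pass
opt_out = rewrite_graph(
    out, include=("canonicalize", "local_subtensor_of_adv_subtensor")
)
```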
46 changes: 45 additions & 1 deletion tests/tensor/rewriting/test_subtensor_lift.py
@@ -52,7 +52,7 @@
)
from pytensor.tensor.shape import SpecifyShape, Unbroadcast, _shape
from pytensor.tensor.special import softmax
from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor


NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None)
@@ -756,3 +756,47 @@ def __eq__(self, other):
    x = shape(Variable(MyType(), None, None))[0]

    assert not local_subtensor_shape_constant.transform(None, x.owner)


@pytest.mark.parametrize(
    "original_fn, supported",
    [
        (lambda x: x[:, [0, 1]][0], True),
        (lambda x: x[:, [0, 1], [0, 0]][1:], True),
        (lambda x: x[:, [[0, 1], [0, 0]]][1:], True),
        # Not supported, basic indexing on advanced indexing dim
        (lambda x: x[[0, 1]][0], False),
        # Not implemented, basic indexing on the right of advanced indexing
        (lambda x: x[[0, 1]][:, 0], False),
        # Not implemented, complex flavors of advanced indexing
        (lambda x: x[:, None, [0, 1]][0], False),
        (lambda x: x[:, 5:, [0, 1]][0], False),
        (lambda x: x[:, :, np.array([True, False, False])][0], False),
        (lambda x: x[[0, 1], :, [0, 1]][:, 0], False),
    ],
)
def test_local_subtensor_of_adv_subtensor(original_fn, supported):
    rng = np.random.default_rng(257)
    x = pt.tensor3("x", shape=(7, 5, 3))
    x_test = rng.normal(size=x.type.shape)

    out = original_fn(x)
    opt_out = rewrite_graph(
        out, include=("canonicalize", "local_subtensor_of_adv_subtensor")
    )
    # The graphs generated are too complicated to assert directly.
    # We simply check that the Subtensor is applied before the AdvancedSubtensor.
    toposort = FunctionGraph(outputs=[opt_out], clone=False).toposort()
    [idx_subtensor] = [
        i for i, node in enumerate(toposort) if isinstance(node.op, Subtensor)
    ]
    [idx_adv_subtensor] = [
        i for i, node in enumerate(toposort) if isinstance(node.op, AdvancedSubtensor)
    ]
    swapped = idx_subtensor < idx_adv_subtensor
    correct = swapped if supported else not swapped
    assert correct, debugprint(opt_out, print_type=True)
    np.testing.assert_allclose(
        opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
    )
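
Not part of the commit: context for the last unsupported test case. NumPy moves the broadcast dimension of non-contiguous advanced indices to the front of the result, which is why the rewrite bails out via `_non_contiguous_adv_indexing`. A small sketch:

```python
import numpy as np

x = np.arange(7 * 5 * 3).reshape(7, 5, 3)
# Contiguous advanced indices: the indexed dims stay in place
print(x[:, [0, 1], [0, 0]].shape)  # (7, 2)
# Non-contiguous advanced indices (separated by a slice): the broadcast
# dimension moves to the front, so a basic index could not be lifted
# over them without also transposing
print(x[[0, 1], :, [0, 1]].shape)  # (2, 5)
```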
